From 2ead553bcb0480c137e5d7168518546cea716428 Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Sun, 7 Mar 2021 18:33:39 -0500 Subject: [PATCH 1/2] Avoid unnecessary keyword arguments check in fallback rrule The expansion `rrule(::Any, ::Vararg{Any}; kwargs...)` actually generates two methods. One along the lines of: ``` (::typeof(Core.kwfunc(rrule)))(kwargs, ::typeof(rrule), ::Any, ::Vararg{Any}) = nothing ``` and the other that just calls it: ``` rrule(a::Any, b::Vararg{Any}) = Core.kwfunc(rrule)(NamedTuple{}(), a, b...) ``` The compiler handles this fallback well, since it's used all over the place, but the cost to infer it is non-zero. Of course, in the AD use case, this fallback method is visited literally on every call, so saving a tiny amount of inference/compile time actually leads to noticeable improvements over a whole AD problem. --- src/rule_definition_tools.jl | 28 +++++++++++++++++++++------- src/rules.jl | 14 +++++++++++++- 2 files changed, 34 insertions(+), 8 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index acbd83e54..c36c39798 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -1,3 +1,5 @@ +using Base.Meta + # These are some macros (and supporting functions) to make it easier to define rules. 
""" @scalar_rule(f(x₁, x₂, ...), @@ -198,7 +200,7 @@ function scalar_rrule_expr(f, call, setup_stmts, inputs, partials) end end -# For context on why this is important, see +# For context on why this is important, see # https://github.com/JuliaDiff/ChainRulesCore.jl/pull/276 "Declares properly hygenic inputs for propagation expressions" _propagator_inputs(n) = [esc(gensym(Symbol(:Δ, i))) for i in 1:n] @@ -307,11 +309,11 @@ macro non_differentiable(sig_expr) unconstrained_args = _unconstrain.(constrained_args) primal_invoke = if !has_vararg - :($(primal_name)($(unconstrained_args...); kwargs...)) + :($(primal_name)($(unconstrained_args...))) else normal_args = unconstrained_args[1:end-1] var_arg = unconstrained_args[end] - :($(primal_name)($(normal_args...), $(var_arg)...; kwargs...)) + :($(primal_name)($(normal_args...), $(var_arg)...)) end quote @@ -320,11 +322,18 @@ macro non_differentiable(sig_expr) end end +"changes `f(x,y)` into `f(x,y; kwargs....)`" +function _with_kwargs_expr(call_expr::Expr) + @assert isexpr(call_expr, :call) + Expr(:call, call_expr.args[1], Expr(:parameters, :(kwargs...)), + call_expr.args[2:end]...) +end + function _nondiff_frule_expr(primal_sig_parts, primal_invoke) return esc(:( function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...) # Julia functions always only have 1 output, so return a single DoesNotExist() - return ($primal_invoke, DoesNotExist()) + return ($(_with_kwargs_expr(primal_invoke)), DoesNotExist()) end )) end @@ -349,11 +358,16 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) Expr(:call, propagator_name(primal_name, :pullback), :_), Expr(:tuple, DoesNotExist(), Expr(:(...), tup_expr)) ) - return esc(:( - function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...) + return esc(quote + # Manully defined kw version to save compiler work. 
+ # See rules.jl + function (::Core.kwftype(typeof(rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...)) + return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr) + end + function ChainRulesCore.rrule($(primal_sig_parts...)) return ($primal_invoke, $pullback_expr) end - )) + end) end diff --git a/src/rules.jl b/src/rules.jl index 35967d27b..f0609894c 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -103,4 +103,16 @@ true See also: [`frule`](@ref), [`@scalar_rule`](@ref) """ -rrule(::Any, ::Vararg{Any}; kwargs...) = nothing +rrule(::Any, ::Vararg{Any}) = nothing + +# Manual fallback for keyword arguments. Usually this would be generated by +# +# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing +# +# However - the fallback method is so hot that we want to avoid any extra code +# that would be required to have the automatically generated method package up +# the keyword arguments (which the optimizer will throw away, but the compiler +# still has to manually analyze). Manually declare this method with an +# explicitly empty body to save the compiler that work. + +(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing From b43c8a644f9c76ba3a3110dc78e20f3dee44c976 Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Wed, 24 Mar 2021 11:32:15 +0000 Subject: [PATCH 2/2] style tweaks --- src/rule_definition_tools.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index c36c39798..4b4c49d00 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -325,8 +325,9 @@ end "changes `f(x,y)` into `f(x,y; kwargs....)`" function _with_kwargs_expr(call_expr::Expr) @assert isexpr(call_expr, :call) - Expr(:call, call_expr.args[1], Expr(:parameters, :(kwargs...)), - call_expr.args[2:end]...) + return Expr( + :call, call_expr.args[1], Expr(:parameters, :(kwargs...)), call_expr.args[2:end]... 
+ ) end function _nondiff_frule_expr(primal_sig_parts, primal_invoke) @@ -359,8 +360,7 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) Expr(:tuple, DoesNotExist(), Expr(:(...), tup_expr)) ) return esc(quote - # Manully defined kw version to save compiler work. - # See rules.jl + # Manually defined kw version to save compiler work. See explanation in rules.jl function (::Core.kwftype(typeof(rrule)))(kwargs::Any, rrule::typeof(ChainRulesCore.rrule), $(primal_sig_parts...)) return ($(_with_kwargs_expr(primal_invoke)), $pullback_expr) end