From 7c40e726c91492417f1074232620fbb02667d31c Mon Sep 17 00:00:00 2001 From: Keno Fischer Date: Thu, 17 Jun 2021 21:49:39 -0400 Subject: [PATCH] Update kwargs fallback rules after RuleConfig rewrite Fixes #368 --- src/rule_definition_tools.jl | 12 +++++++----- src/rules.jl | 38 +++++++++++++++++++++--------------- test/config.jl | 16 +++++++++++++++ 3 files changed, 45 insertions(+), 21 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 57d3c47ca..a1d67490d 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -342,14 +342,16 @@ end function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) @gensym kwargs - # `::Any` instead of `_`: https://github.com/JuliaLang/julia/issues/32727 return @strip_linenos quote - function ChainRulesCore.frule( - @nospecialize(::Any), $(map(esc, primal_sig_parts)...); $(esc(kwargs))... - ) + # Manually defined kw version to save compiler work. See explanation in rules.jl + function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any), + frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...)) + return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) + end + function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...)) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent() - return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) + return ($(esc(primal_invoke)), NoTangent()) end end end diff --git a/src/rules.jl b/src/rules.jl index 232b1c558..b0b830fb7 100644 --- a/src/rules.jl +++ b/src/rules.jl @@ -58,10 +58,24 @@ will be hit as a fallback. This is the case for most rules. See also: [`rrule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref) """ -frule(::Any, ::Any, ::Vararg{Any}; kwargs...) = nothing +frule(ȧrgs, f, ::Vararg{Any}) = nothing # if no config is present then fallback to config-less rules -frule(::RuleConfig, ȧrgs, f, args...; kwargs...) = frule(ȧrgs, f, args...; kwargs...) +frule(::RuleConfig, args...) = frule(args...) + +# Manual fallback for keyword arguments. Usually this would be generated by +# +# frule(::Any, ::Vararg{Any}; kwargs...) = nothing +# +# However - the fallback method is so hot that we want to avoid any extra code +# that would be required to have the automatically generated method package up +# the keyword arguments (which the optimizer will throw away, but the compiler +# still has to manually analyze). Manually declare this method with an +# explicitly empty body to save the compiler that work. +const frule_kwfunc = Core.kwftype(typeof(frule)).instance +(::typeof(frule_kwfunc))(::Any, ::typeof(frule), ȧrgs, f, ::Vararg{Any}) = nothing +(::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args...) = + (frule_kwfunc)(kws, frule, args...) """ rrule([::RuleConfig,] f, x...) @@ -116,18 +130,10 @@ See also: [`frule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref) rrule(::Any, ::Vararg{Any}) = nothing # if no config is present then fallback to config-less rules -rrule(::RuleConfig, f, args...; kwargs...) = rrule(f, args...; kwargs...) -# TODO do we need to do something for kwargs special here for performance? -# See: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/368 - -# Manual fallback for keyword arguments. Usually this would be generated by -# -# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing -# -# However - the fallback method is so hot that we want to avoid any extra code -# that would be required to have the automatically generated method package up -# the keyword arguments (which the optimizer will throw away, but the compiler -# still has to manually analyze). Manually declare this method with an -# explicitly empty body to save the compiler that work. +rrule(::RuleConfig, args...) = rrule(args...) -(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing +# Manual fallback for keyword arguments. See above +const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance +(::typeof(rrule_kwfunc))(::Any, ::typeof(rrule), ::Any, ::Vararg{Any}) = nothing +(::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) = + (rrule_kwfunc)(kws, rrule, args...) diff --git a/test/config.jl b/test/config.jl index c55c0b6c8..d8465db97 100644 --- a/test/config.jl +++ b/test/config.jl @@ -161,3 +161,19 @@ end @test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}} end end + +@testset "fallbacks" begin + # Test that incorrect use of the fallback rules correctly throws MethodError + @test_throws MethodError frule() + @test_throws MethodError frule(;kw="hello") + @test_throws MethodError frule(sin) + @test_throws MethodError frule(sin;kw="hello") + @test_throws MethodError frule(MostBoringConfig()) + @test_throws MethodError frule(MostBoringConfig(); kw="hello") + @test_throws MethodError frule(MostBoringConfig(), sin) + @test_throws MethodError frule(MostBoringConfig(), sin; kw="hello") + @test_throws MethodError rrule() + @test_throws MethodError rrule(;kw="hello") + @test_throws MethodError rrule(MostBoringConfig()) + @test_throws MethodError rrule(MostBoringConfig();kw="hello") +end