Skip to content

Commit

Permalink
Update kwargs fallback rules after RuleConfig rewrite
Browse files Browse the repository at this point in the history
  • Loading branch information
Keno committed Jun 18, 2021
1 parent bf01ddf commit 7c40e72
Show file tree
Hide file tree
Showing 3 changed files with 45 additions and 21 deletions.
12 changes: 7 additions & 5 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -342,14 +342,16 @@ end

function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
@gensym kwargs
# `::Any` instead of `_`: https://github.com/JuliaLang/julia/issues/32727
return @strip_linenos quote
function ChainRulesCore.frule(
@nospecialize(::Any), $(map(esc, primal_sig_parts)...); $(esc(kwargs))...
)
# Manually defined kw version to save compiler work. See explanation in rules.jl
function (::Core.kwftype(typeof(ChainRulesCore.frule)))(@nospecialize($kwargs::Any),
frule::typeof(ChainRulesCore.frule), @nospecialize(::Any), $(map(esc, primal_sig_parts)...))
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent())
end
function ChainRulesCore.frule(@nospecialize(::Any), $(map(esc, primal_sig_parts)...))
$(__source__)
# Julia functions always only have 1 output, so return a single NoTangent()
return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent())
return ($(esc(primal_invoke)), NoTangent())
end
end
end
Expand Down
38 changes: 22 additions & 16 deletions src/rules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,10 +58,24 @@ will be hit as a fallback. This is the case for most rules.
See also: [`rrule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref)
"""
frule(::Any, ::Any, ::Vararg{Any}; kwargs...) = nothing
frule(ȧrgs, f, ::Vararg{Any}) = nothing

# if no config is present then fallback to config-less rules
frule(::RuleConfig, ȧrgs, f, args...; kwargs...) = frule(ȧrgs, f, args...; kwargs...)
frule(::RuleConfig, args...) = frule(args...)

# Manual fallback for keyword arguments. Usually this would be generated by
#
# frule(::Any, ::Vararg{Any}; kwargs...) = nothing
#
# However - the fallback method is so hot that we want to avoid any extra code
# that would be required to have the automatically generated method package up
# the keyword arguments (which the optimizer will throw away, but the compiler
# still has to manually analyze). Manually declare this method with an
# explicitly empty body to save the compiler that work.
const frule_kwfunc = Core.kwftype(typeof(frule)).instance
(::typeof(frule_kwfunc))(::Any, ::typeof(frule), ȧrgs, f, ::Vararg{Any}) = nothing
(::typeof(frule_kwfunc))(kws::Any, ::typeof(frule), ::RuleConfig, args...) =
(frule_kwfunc)(kws, frule, args...)

"""
rrule([::RuleConfig,] f, x...)
Expand Down Expand Up @@ -116,18 +130,10 @@ See also: [`frule`](@ref), [`@scalar_rule`](@ref), [`RuleConfig`](@ref)
rrule(::Any, ::Vararg{Any}) = nothing

# if no config is present then fallback to config-less rules
rrule(::RuleConfig, f, args...; kwargs...) = rrule(f, args...; kwargs...)
# TODO do we need to do something for kwargs special here for performance?
# See: https://github.com/JuliaDiff/ChainRulesCore.jl/issues/368

# Manual fallback for keyword arguments. Usually this would be generated by
#
# rrule(::Any, ::Vararg{Any}; kwargs...) = nothing
#
# However - the fallback method is so hot that we want to avoid any extra code
# that would be required to have the automatically generated method package up
# the keyword arguments (which the optimizer will throw away, but the compiler
# still has to manually analyze). Manually declare this method with an
# explicitly empty body to save the compiler that work.
rrule(::RuleConfig, args...) = rrule(args...)

(::Core.kwftype(typeof(rrule)))(::Any, ::Any, ::Vararg{Any}) = nothing
# Manual fallback for keyword arguments. See above
const rrule_kwfunc = Core.kwftype(typeof(rrule)).instance
(::typeof(rrule_kwfunc))(::Any, ::typeof(rrule), ::Any, ::Vararg{Any}) = nothing
(::typeof(rrule_kwfunc))(kws::Any, ::typeof(rrule), ::RuleConfig, args...) =
(rrule_kwfunc)(kws, rrule, args...)
16 changes: 16 additions & 0 deletions test/config.jl
Original file line number Diff line number Diff line change
Expand Up @@ -161,3 +161,19 @@ end
@test (MostBoringConfig() .=> (1,2,3)) isa NTuple{3, Pair{MostBoringConfig,Int}}
end
end

@testset "fallbacks" begin
# Test that incorrect use of the fallback rules correctly throws MethodError
@test_throws MethodError frule()
@test_throws MethodError frule(;kw="hello")
@test_throws MethodError frule(sin)
@test_throws MethodError frule(sin;kw="hello")
@test_throws MethodError frule(MostBoringConfig())
@test_throws MethodError frule(MostBoringConfig(); kw="hello")
@test_throws MethodError frule(MostBoringConfig(), sin)
@test_throws MethodError frule(MostBoringConfig(), sin; kw="hello")
@test_throws MethodError rrule()
@test_throws MethodError rrule(;kw="hello")
@test_throws MethodError rrule(MostBoringConfig())
@test_throws MethodError rrule(MostBoringConfig();kw="hello")
end

0 comments on commit 7c40e72

Please sign in to comment.