From ce047f93d0f3f19b55b15d6da4231cdef6b42f8c Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Thu, 10 Sep 2020 13:34:28 +0100 Subject: [PATCH 1/2] Add kwarg support to at-nondifferentiable --- Project.toml | 2 +- src/rule_definition_tools.jl | 28 ++++++++++++++-------------- test/rule_definition_tools.jl | 26 ++++++++++++++++++++++++++ 3 files changed, 41 insertions(+), 15 deletions(-) diff --git a/Project.toml b/Project.toml index 9fe1a6192..e6de77d99 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRulesCore" uuid = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -version = "0.9.8" +version = "0.9.9" [deps] MuladdMacro = "46d2c3a1-f734-5fdb-9937-b9b9aeba4221" diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index b686ffc0d..320030eb9 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -295,7 +295,8 @@ macro non_differentiable(sig_expr) primal_sig_parts = [:(::typeof($primal_name)), constrained_args...] unconstrained_args = _unconstrain.(constrained_args) - primal_invoke = Expr(:call, esc(primal_name), esc.(unconstrained_args)...) + + primal_invoke = :($(primal_name)($(unconstrained_args...); kwargs...)) quote $(_nondiff_frule_expr(primal_sig_parts, primal_invoke)) @@ -304,12 +305,12 @@ macro non_differentiable(sig_expr) end function _nondiff_frule_expr(primal_sig_parts, primal_invoke) - return Expr( - :(=), - Expr(:call, :(ChainRulesCore.frule), esc(:_), esc.(primal_sig_parts)...), - # Julia functions always only have 1 output, so just return a single DoesNotExist() - Expr(:tuple, primal_invoke, DoesNotExist()), - ) + return esc(:( + function ChainRulesCore.frule($(gensym(:_)), $(primal_sig_parts...); kwargs...) + # Julia functions always only have 1 output, so return a single DoesNotExist() + return ($primal_invoke, DoesNotExist()) + end + )) end function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) @@ -317,15 +318,14 @@ function _nondiff_rrule_expr(primal_sig_parts, primal_invoke) primal_name = first(primal_invoke.args) pullback_expr = Expr( :function, - Expr(:call, esc(propagator_name(primal_name, :pullback)), esc(:_)), + Expr(:call, propagator_name(primal_name, :pullback), :_), Expr(:tuple, NO_FIELDS, ntuple(_->DoesNotExist(), num_primal_inputs)...) ) - rrule_defn = Expr( - :(=), - Expr(:call, :(ChainRulesCore.rrule), esc.(primal_sig_parts)...), - Expr(:tuple, primal_invoke, pullback_expr), - ) - return rrule_defn + return esc(:( + function ChainRulesCore.rrule($(primal_sig_parts...); kwargs...) + return ($primal_invoke, $pullback_expr) + end + )) end diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index 7a05cbfec..d37e252d4 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -73,6 +73,32 @@ end @test rrule(pointy_identity, 2.0) == nothing end + @testset "kwargs" begin + kw_demo(x; kw=2.0) = x + kw + @non_differentiable kw_demo(::Any) + + @testset "not setting kw" begin + @assert kw_demo(1.5) == 3.5 + + res, pullback = rrule(kw_demo, 1.5) + @test res == 3.5 + @test pullback(4.1) == (NO_FIELDS, DoesNotExist()) + + @test frule((Zero(), 11.1), kw_demo, 1.5) == (3.5, DoesNotExist()) + end + + @testset "setting kw" begin + @assert kw_demo(1.5; kw=3.0) == 4.5 + + res, pullback = rrule(kw_demo, 1.5; kw=3.0) + @test res == 4.5 + @test pullback(1.1) == (NO_FIELDS, DoesNotExist()) + + @test frule((Zero(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, DoesNotExist()) + end + + end + @testset "Not supported (Yet)" begin # Varargs are not supported @test_macro_throws ErrorException @non_differentiable vararg1(xs...) From 185e5186383e04af3d44f037d981f57b66fc456f Mon Sep 17 00:00:00 2001 From: Lyndon White Date: Fri, 11 Sep 2020 19:09:28 +0100 Subject: [PATCH 2/2] remvoe excess whitespace Co-authored-by: mattBrzezinski --- test/rule_definition_tools.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index d37e252d4..fc979f049 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -96,7 +96,6 @@ end @test frule((Zero(), 11.1), kw_demo, 1.5; kw=3.0) == (4.5, DoesNotExist()) end - end @testset "Not supported (Yet)" begin