From 82d6dc967891a17800fc8c244410a94ff50ddc84 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 15 Jan 2024 22:41:38 +0100 Subject: [PATCH 1/3] make `@scalar_rule` use `::Tuple` --- src/rule_definition_tools.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 10ce7beec..a52cd7a86 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -209,7 +209,7 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) return @strip_linenos quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - function ChainRulesCore.frule((_, $(Δs...)), ::Core.Typeof($f), $(inputs...)) + function ChainRulesCore.frule((_, $(Δs...))::Tuple, ::Core.Typeof($f), $(inputs...)) $(__source__) $(esc(:Ω)) = $call $(setup_stmts...) From ee4e1c81a60281a531f626d6192f03e37abf0e37 Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Mon, 15 Jan 2024 22:48:51 +0100 Subject: [PATCH 2/3] try duplicating to avoid new ambiguities from arising --- src/rule_definition_tools.jl | 9 +++++++++ 1 file changed, 9 insertions(+) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index a52cd7a86..e03f7e378 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -218,6 +218,15 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) ) return $(esc(:Ω)), $pushforward_returns end + function ChainRulesCore.frule((_, $(Δs...)), ::Core.Typeof($f), $(inputs...)) + $(__source__) + $(esc(:Ω)) = $call + $(setup_stmts...) + $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( + $(esc(:Ω)), $f, $(inputs...) + ) + return $(esc(:Ω)), $pushforward_returns + end end end From 52c8bc60ccc0fa308ab49e37f8159a4f3ffe702d Mon Sep 17 00:00:00 2001 From: Mason Protter Date: Tue, 16 Jan 2024 10:03:26 +0100 Subject: [PATCH 3/3] try modifying `@non_differentiable` instead --- src/rule_definition_tools.jl | 11 +---------- 1 file changed, 1 insertion(+), 10 deletions(-) diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index e03f7e378..0bc8f1e5b 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -209,15 +209,6 @@ function scalar_frule_expr(__source__, f, call, setup_stmts, inputs, partials) return @strip_linenos quote # _ is the input derivative w.r.t. function internals. since we do not # allow closures/functors with @scalar_rule, it is always ignored - function ChainRulesCore.frule((_, $(Δs...))::Tuple, ::Core.Typeof($f), $(inputs...)) - $(__source__) - $(esc(:Ω)) = $call - $(setup_stmts...) - $(Expr(:tuple, partials...)) = ChainRulesCore.derivatives_given_output( - $(esc(:Ω)), $f, $(inputs...) - ) - return $(esc(:Ω)), $pushforward_returns - end function ChainRulesCore.frule((_, $(Δs...)), ::Core.Typeof($f), $(inputs...)) $(__source__) $(esc(:Ω)) = $call @@ -418,7 +409,7 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - @nospecialize(::Tuple), $(map(esc, primal_sig_parts)...) + @nospecialize(::Any), $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent()