diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index 10ce7beec..0bc8f1e5b 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -409,7 +409,7 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) return ($(esc(_with_kwargs_expr(primal_invoke, kwargs))), NoTangent()) end function ChainRulesCore.frule( - @nospecialize(::Tuple), $(map(esc, primal_sig_parts)...) + @nospecialize(::Any), $(map(esc, primal_sig_parts)...) ) $(__source__) # Julia functions always only have 1 output, so return a single NoTangent()