Skip to content

Commit

Permalink
make @non_differentiable use identical pullbacks when possible
Browse files Browse the repository at this point in the history
Fixes #678
  • Loading branch information
nsajko committed May 30, 2024
1 parent fa530b9 commit ae7b114
Show file tree
Hide file tree
Showing 2 changed files with 53 additions and 7 deletions.
19 changes: 12 additions & 7 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,27 +418,32 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
end
end

function tuple_expression(primal_sig_parts)
function _make_pullback_for_non_differentiable(::Val{N}) where {N}
Vararg{Any,N} # throw early for invalid `N`, must be nonnegative `Int`
function pullback_for_non_differentiable(::Any)
ntuple(Returns(NoTangent()), Val(N))
end
end

function tuple_length_expression(primal_sig_parts)
has_vararg = _isvararg(primal_sig_parts[end])
return if !has_vararg
num_primal_inputs = length(primal_sig_parts)
Expr(:tuple, ntuple(_ -> NoTangent(), num_primal_inputs)...)
:($num_primal_inputs)
else
num_primal_inputs = length(primal_sig_parts) - 1 # - vararg
length_expr =
:($num_primal_inputs + length($(esc(_unconstrain(primal_sig_parts[end])))))
@strip_linenos :(ntuple(i -> NoTangent(), $length_expr))
@strip_linenos :($length_expr)
end
end

function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
esc_primal_sig_parts = map(esc, primal_sig_parts)
tup_expr = tuple_expression(primal_sig_parts)
tup_len_expr = tuple_length_expression(primal_sig_parts)
primal_name = first(primal_invoke.args)
pullback_expr = @strip_linenos quote
function $(esc(propagator_name(primal_name, :pullback)))(@nospecialize(_))
return $(tup_expr)
end
_make_pullback_for_non_differentiable(Val{$(tup_len_expr)}())
end

@gensym kwargs
Expand Down
41 changes: 41 additions & 0 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,47 @@ end

@testset "rule_definition_tools.jl" begin
@testset "@non_differentiable" begin
@testset "`_make_pullback_for_non_differentiable`" begin
f = ChainRulesCore._make_pullback_for_non_differentiable
@testset "throws on invalid input" begin
@test_throws Exception f(Val(0.0))
@test_throws Exception f(Val(-1))
end
@testset "identical objects" begin
for i 0:5
v = Val(i)
@test f(v) === f(v)
end
end
@testset "correctness" begin
for i 0:5
expected = ntuple((_ -> NoTangent()), i)
@test f(Val(i))(:arbitrary) === expected
end
end
@testset "dispatch" begin
for i 0:5
pullback = f(Val(i))
@test_throws MethodError pullback()
@test_throws MethodError pullback(1, 2)
end
end
end

@testset "issue #678: identical pullback objects" begin
issue_678_f(::Any) = nothing
issue_678_g(::Any) = nothing
issue_678_h(::Any...) = nothing
@non_differentiable issue_678_f(::Any)
@non_differentiable issue_678_g(::Any)
@non_differentiable issue_678_h(::Any...)
@test (
last(rrule(issue_678_f, 0.1)) ===
last(rrule(issue_678_g, 0.2)) ===
last(rrule(issue_678_h, 0.3))
)
end

@testset "two input one output function" begin
nondiff_2_1(x, y) = fill(7.5, 100)[x + y]
@non_differentiable nondiff_2_1(::Any, ::Any)
Expand Down

0 comments on commit ae7b114

Please sign in to comment.