Skip to content

Commit

Permalink
simpler
Browse files Browse the repository at this point in the history
  • Loading branch information
nsajko committed May 31, 2024
1 parent 191eb47 commit 122d83d
Show file tree
Hide file tree
Showing 3 changed files with 2 additions and 22 deletions.
2 changes: 1 addition & 1 deletion src/ChainRulesCore.jl
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ module ChainRulesCore
using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize!
using Base.Meta
using LinearAlgebra
using Compat: hasfield, hasproperty, ismutabletype
using Compat: hasfield, hasproperty, ismutabletype, Returns

export frule, rrule # core function
# rule configurations
Expand Down
10 changes: 1 addition & 9 deletions src/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -418,14 +418,6 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke)
end
end

struct NonDiffPullback{T<:Tuple{Vararg{NoTangent}}} <: Function
v::T
end

function (@nospecialize pb::NonDiffPullback)(@nospecialize ::Any)
return pb.v
end

function tuple_expression(primal_sig_parts)
has_vararg = _isvararg(primal_sig_parts[end])
return if !has_vararg
Expand All @@ -444,7 +436,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke)
tup_expr = tuple_expression(primal_sig_parts)
primal_name = first(primal_invoke.args)
pullback_expr = @strip_linenos quote
NonDiffPullback($(tup_expr))
Returns($(tup_expr))
end

@gensym kwargs
Expand Down
12 changes: 0 additions & 12 deletions test/rule_definition_tools.jl
Original file line number Diff line number Diff line change
Expand Up @@ -42,18 +42,6 @@ end

@testset "rule_definition_tools.jl" begin
@testset "@non_differentiable" begin
@testset "`NonDiffPullback`" begin
NDP = ChainRulesCore.NonDiffPullback
for i in 0:5
tup = ntuple((_ -> NoTangent()), i)
ndp = NDP(tup)
@test ndp === @inferred NDP(tup)
@test tup === @inferred ndp(:arbitrary)
@test_throws MethodError ndp()
@test_throws MethodError ndp(1, 2)
end
end

@testset "issue #678: identical pullback objects" begin
issue_678_f(::Any) = nothing
issue_678_g(::Any) = nothing
Expand Down

0 comments on commit 122d83d

Please sign in to comment.