From 122d83db65ad6907cb54ddc292f903b2ffeaaf1c Mon Sep 17 00:00:00 2001 From: Neven Sajko Date: Fri, 31 May 2024 13:25:52 +0200 Subject: [PATCH] simpler --- src/ChainRulesCore.jl | 2 +- src/rule_definition_tools.jl | 10 +--------- test/rule_definition_tools.jl | 12 ------------ 3 files changed, 2 insertions(+), 22 deletions(-) diff --git a/src/ChainRulesCore.jl b/src/ChainRulesCore.jl index 286f71db2..544b3d9c6 100644 --- a/src/ChainRulesCore.jl +++ b/src/ChainRulesCore.jl @@ -2,7 +2,7 @@ module ChainRulesCore using Base.Broadcast: broadcasted, Broadcasted, broadcastable, materialize, materialize! using Base.Meta using LinearAlgebra -using Compat: hasfield, hasproperty, ismutabletype +using Compat: hasfield, hasproperty, ismutabletype, Returns export frule, rrule # core function # rule configurations diff --git a/src/rule_definition_tools.jl b/src/rule_definition_tools.jl index e087f6dde..88ec3e8aa 100644 --- a/src/rule_definition_tools.jl +++ b/src/rule_definition_tools.jl @@ -418,14 +418,6 @@ function _nondiff_frule_expr(__source__, primal_sig_parts, primal_invoke) end end -struct NonDiffPullback{T<:Tuple{Vararg{NoTangent}}} <: Function - v::T -end - -function (@nospecialize pb::NonDiffPullback)(@nospecialize ::Any) - return pb.v -end - function tuple_expression(primal_sig_parts) has_vararg = _isvararg(primal_sig_parts[end]) return if !has_vararg @@ -444,7 +436,7 @@ function _nondiff_rrule_expr(__source__, primal_sig_parts, primal_invoke) tup_expr = tuple_expression(primal_sig_parts) primal_name = first(primal_invoke.args) pullback_expr = @strip_linenos quote - NonDiffPullback($(tup_expr)) + Returns($(tup_expr)) end @gensym kwargs diff --git a/test/rule_definition_tools.jl b/test/rule_definition_tools.jl index bac39e694..de31941ed 100644 --- a/test/rule_definition_tools.jl +++ b/test/rule_definition_tools.jl @@ -42,18 +42,6 @@ end @testset "rule_definition_tools.jl" begin @testset "@non_differentiable" begin - @testset "`NonDiffPullback`" begin - NDP = ChainRulesCore.NonDiffPullback - for i in 0:5 - tup = ntuple((_ -> NoTangent()), i) - ndp = NDP(tup) - @test ndp === @inferred NDP(tup) - @test tup === @inferred ndp(:arbitrary) - @test_throws MethodError ndp() - @test_throws MethodError ndp(1, 2) - end - end - @testset "issue #678: identical pullback objects" begin issue_678_f(::Any) = nothing issue_678_g(::Any) = nothing