From a05ebb28cd6379b8221111823835f8172861e248 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 20 May 2024 23:46:27 +0800 Subject: [PATCH 1/4] Add rrule for NamedTuple merge --- src/rulesets/Base/base.jl | 22 ++++++++++++++++++++++ test/rulesets/Base/base.jl | 5 +++++ 2 files changed, 27 insertions(+) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 6c66d19ee..951f62922 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -291,3 +291,25 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage end return y, task_local_storage_pullback end + + +#### +#### merge +#### + +function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2} + y = merge(nt1, nt2) + function merge_pullback(dy) + dnt1 = Tangent{typeof(nt1)}(; + (f1 => (f1 in F2 ? ZeroTangent() : getproperty(dy, f1)) for f1 in F1)... + ) + dnt2 = Tangent{typeof(nt2)}(; + (f2 => getproperty(dy, f2) for f2 in F2)... + ) + return (NoTangent(), dnt1, dnt2) + end + merge_pullback(dy::AbstractThunk) = merge_pullback(unthunk(dy)) + merge_pullback(x::AbstractZero) = (NoTangent(), x, x) + + return y, merge_pullback +end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 25c755f55..ec87640b7 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -257,4 +257,9 @@ end test_rrule(map, Multiplier(4.5), (6.7, 8.9), (0.1, 0.2, 0.3), check_inferred=false) end end + + @testset "merge NamedTuple" begin + test_rrule(merge, (;a=1.0), (;b=2.0), check_inferred=false) + test_rrule(merge, (;a=1.0), (;a=2.0), check_inferred=false) + end end From d40e282ea010621ee79115b8ed48079dcf11cbcd Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 21 May 2024 17:11:53 +0800 Subject: [PATCH 2/4] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rulesets/Base/base.jl | 8 ++------ test/rulesets/Base/base.jl | 4 ++-- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index 951f62922..ea4fb2285 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -292,24 +292,20 @@ function rrule(config::RuleConfig{>:HasReverseMode}, ::typeof(task_local_storage return y, task_local_storage_pullback end - #### #### merge #### -function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1, F2} +function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1,F2} y = merge(nt1, nt2) function merge_pullback(dy) dnt1 = Tangent{typeof(nt1)}(; (f1 => (f1 in F2 ? ZeroTangent() : getproperty(dy, f1)) for f1 in F1)... ) - dnt2 = Tangent{typeof(nt2)}(; - (f2 => getproperty(dy, f2) for f2 in F2)... - ) + dnt2 = Tangent{typeof(nt2)}(; (f2 => getproperty(dy, f2) for f2 in F2)...) return (NoTangent(), dnt1, dnt2) end merge_pullback(dy::AbstractThunk) = merge_pullback(unthunk(dy)) merge_pullback(x::AbstractZero) = (NoTangent(), x, x) - return y, merge_pullback end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index ec87640b7..f1029588a 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -259,7 +259,7 @@ end end @testset "merge NamedTuple" begin - test_rrule(merge, (;a=1.0), (;b=2.0), check_inferred=false) - test_rrule(merge, (;a=1.0), (;a=2.0), check_inferred=false) + test_rrule(merge, (; a=1.0), (; b=2.0); check_inferred=false) + test_rrule(merge, (; a=1.0), (; a=2.0); check_inferred=false) end end From 44f1a5eb0055a9b7304650e701c9c274055edb24 Mon Sep 17 00:00:00 2001 From: Frames White Date: Tue, 21 May 2024 17:58:57 +0800 Subject: [PATCH 3/4] Rewrite as a generated functor --- src/rulesets/Base/base.jl | 26 +++++++++++++++----------- test/rulesets/Base/base.jl | 4 ++-- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index ea4fb2285..a2b195701 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -295,17 +295,21 @@ end #### #### merge #### - -function rrule(::typeof(merge), nt1::NamedTuple{F1}, nt2::NamedTuple{F2}) where {F1,F2} - y = merge(nt1, nt2) - function merge_pullback(dy) - dnt1 = Tangent{typeof(nt1)}(; - (f1 => (f1 in F2 ? ZeroTangent() : getproperty(dy, f1)) for f1 in F1)... - ) - dnt2 = Tangent{typeof(nt2)}(; (f2 => getproperty(dy, f2) for f2 in F2)...) +# need to work around inability to return closures from generated functions +struct MergePullback{T1, T2} +end +(this::MergePullback)(dy::AbstractThunk) = this(unthunk(dy)) +(::MergePullback)(x::AbstractZero) = (NoTangent(), x, x) +@generated function(::MergePullback{T1,T2})(dy::Tangent) where {F1,T1<:NamedTuple{F1},F2,T2<:NamedTuple{F2}} + _getproperty_kwexpr(key) = :($key = getproperty(dy, $(Meta.quot(key)))) + quote + dnt1 = Tangent{T1}(; $(map(_getproperty_kwexpr, setdiff(F1, F2))...)) + dnt2 = Tangent{T2}(; $(map(_getproperty_kwexpr, F2)...)) return (NoTangent(), dnt1, dnt2) end - merge_pullback(dy::AbstractThunk) = merge_pullback(unthunk(dy)) - merge_pullback(x::AbstractZero) = (NoTangent(), x, x) - return y, merge_pullback +end + +function rrule(::typeof(merge), nt1::T1, nt2::T2) where {T1<:NamedTuple, T2<:NamedTuple} + y = merge(nt1, nt2) + return y, MergePullback{T1,T2}() end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index f1029588a..2c8412a65 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -259,7 +259,7 @@ end end @testset "merge NamedTuple" begin - test_rrule(merge, (; a=1.0), (; b=2.0); check_inferred=false) - test_rrule(merge, (; a=1.0), (; a=2.0); check_inferred=false) + test_rrule(merge, (;a=1.0), (;b=2.0)) + test_rrule(merge, (;a=1.0), (;a=2.0)) end end From 49d5ae769dcf9c01b06cebc75681c7d2800c8d3f Mon Sep 17 00:00:00 2001 From: Frames White Date: Wed, 22 May 2024 19:17:43 +0800 Subject: [PATCH 4/4] Apply suggestions from code review Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- src/rulesets/Base/base.jl | 9 +++++---- test/rulesets/Base/base.jl | 4 ++-- 2 files changed, 7 insertions(+), 6 deletions(-) diff --git a/src/rulesets/Base/base.jl b/src/rulesets/Base/base.jl index a2b195701..0c81b05cc 100644 --- a/src/rulesets/Base/base.jl +++ b/src/rulesets/Base/base.jl @@ -296,11 +296,12 @@ end #### merge #### # need to work around inability to return closures from generated functions -struct MergePullback{T1, T2} -end +struct MergePullback{T1,T2} end (this::MergePullback)(dy::AbstractThunk) = this(unthunk(dy)) (::MergePullback)(x::AbstractZero) = (NoTangent(), x, x) -@generated function(::MergePullback{T1,T2})(dy::Tangent) where {F1,T1<:NamedTuple{F1},F2,T2<:NamedTuple{F2}} +@generated function (::MergePullback{T1,T2})( + dy::Tangent +) where {F1,T1<:NamedTuple{F1},F2,T2<:NamedTuple{F2}} _getproperty_kwexpr(key) = :($key = getproperty(dy, $(Meta.quot(key)))) quote dnt1 = Tangent{T1}(; $(map(_getproperty_kwexpr, setdiff(F1, F2))...)) @@ -309,7 +310,7 @@ end end end -function rrule(::typeof(merge), nt1::T1, nt2::T2) where {T1<:NamedTuple, T2<:NamedTuple} +function rrule(::typeof(merge), nt1::T1, nt2::T2) where {T1<:NamedTuple,T2<:NamedTuple} y = merge(nt1, nt2) return y, MergePullback{T1,T2}() end diff --git a/test/rulesets/Base/base.jl b/test/rulesets/Base/base.jl index 2c8412a65..1e738a6a5 100644 --- a/test/rulesets/Base/base.jl +++ b/test/rulesets/Base/base.jl @@ -259,7 +259,7 @@ end end @testset "merge NamedTuple" begin - test_rrule(merge, (;a=1.0), (;b=2.0)) - test_rrule(merge, (;a=1.0), (;a=2.0)) + test_rrule(merge, (; a=1.0), (; b=2.0)) + test_rrule(merge, (; a=1.0), (; a=2.0)) end end