From b72df42c543e57315c136862b93226f3185ea8e7 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 5 Dec 2023 17:08:34 +1300 Subject: [PATCH 1/3] =?UTF-8?q?add=20illustration=20of=20=E2=88=82self=20i?= =?UTF-8?q?n=20the=20maths/propagators=20section?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- docs/src/maths/propagators.md | 18 ++++++++++++++++++ 1 file changed, 18 insertions(+) diff --git a/docs/src/maths/propagators.md b/docs/src/maths/propagators.md index aba531f55..acfc43a44 100644 --- a/docs/src/maths/propagators.md +++ b/docs/src/maths/propagators.md @@ -179,6 +179,24 @@ So every `pushforward` takes in an extra argument, which is ignored unless the o It is common to write `function foo_pushforward(_, Δargs...)` in the case when `foo` does not have fields. Similarly every `pullback` returns an extra `∂self`, which for things without fields is `NoTangent()`, indicating there are no fields within the function itself. +Here's an example showing how to define `∂self` in an `rrule` when the primal function has +internal fields (implicit arguments): + +```julia +struct Multiplier{T} + x::T +end +(m::Multiplier)(y) = m.x * y + +function ChainRulesCore.rrule(m::Multiplier, y) + product = m(y) + function pullback(Δproduct) + ∂self = Tangent{typeof(m)}(; x = Δproduct * y') + ∂y = m.x' * Δproduct + return ∂self, ∂y + return product, pullback +end +``` ### Pushforward / Pullback summary From dc7bf6b6bc2c50b076425145855e7a833e992570 Mon Sep 17 00:00:00 2001 From: "Anthony D. Blaom" Date: Tue, 5 Dec 2023 17:18:31 +1300 Subject: [PATCH 2/3] oops. Fix missing `end` --- docs/src/maths/propagators.md | 7 ++++--- 1 file changed, 4 insertions(+), 3 deletions(-) diff --git a/docs/src/maths/propagators.md b/docs/src/maths/propagators.md index acfc43a44..74986ac83 100644 --- a/docs/src/maths/propagators.md +++ b/docs/src/maths/propagators.md @@ -191,9 +191,10 @@ end function ChainRulesCore.rrule(m::Multiplier, y) product = m(y) function pullback(Δproduct) - ∂self = Tangent{typeof(m)}(; x = Δproduct * y') - ∂y = m.x' * Δproduct - return ∂self, ∂y + ∂self = Tangent{typeof(m)}(; x = Δproduct * y') + ∂y = m.x' * Δproduct + return ∂self, ∂y + end return product, pullback end ``` From b4664d0fd446faabc1c26194e08122e7939ff2b6 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 8 Dec 2023 18:23:48 +0800 Subject: [PATCH 3/3] Update docs/src/maths/propagators.md Co-authored-by: Anthony Blaom, PhD --- docs/src/maths/propagators.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/src/maths/propagators.md b/docs/src/maths/propagators.md index 74986ac83..d212160d5 100644 --- a/docs/src/maths/propagators.md +++ b/docs/src/maths/propagators.md @@ -191,7 +191,7 @@ end function ChainRulesCore.rrule(m::Multiplier, y) product = m(y) function pullback(Δproduct) - ∂self = Tangent{typeof(m)}(; x = Δproduct * y') + ∂self = Tangent{Multiplier}(; x = Δproduct * y') ∂y = m.x' * Δproduct return ∂self, ∂y end