From ae339b693d8d48cae2def49a1a7b620eb6591817 Mon Sep 17 00:00:00 2001 From: Frames White Date: Fri, 9 Feb 2024 18:50:35 +0800 Subject: [PATCH 1/3] Use ProjectTo in multiarg + --- src/rulesets/Base/arraymath.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 078bb602a..89ed9fe8a 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -443,10 +443,10 @@ frule((_, ΔAs...), ::typeof(+), As::AbstractArray...) = +(As...), +(ΔAs...) function rrule(::typeof(+), arrs::AbstractArray...) y = +(arrs...) - arr_axs = map(axes, arrs) + projs = map(ProjectTo, arrs) function add_pullback(dy_raw) - dy = unthunk(dy_raw) # reshape will otherwise unthunk N times - return (NoTangent(), map(ax -> reshape(dy, ax), arr_axs)...) + dy = unthunk(dy_raw) # projs will otherwise unthunk N times + return (NoTangent(), map(proj -> proj(dy), projs)...) end return y, add_pullback end From abf9e3757eb3ead01a519fdeeb6b32eb285d87f3 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 12 Feb 2024 14:33:36 +0800 Subject: [PATCH 2/3] Add tricky cases --- test/rulesets/Base/arraymath.jl | 2 ++ 1 file changed, 2 insertions(+) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 2682f3b8a..7c76aa17a 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -216,5 +216,7 @@ # rev @gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) @gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) + test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1)) + test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3))) end end From 5f078bfee88cc6c3da41fe91e939f2f7e8b68a36 Mon Sep 17 00:00:00 2001 From: Frames White Date: Mon, 12 Feb 2024 14:36:01 +0800 Subject: [PATCH 3/3] formatting Co-authored-by: github-actions[bot] <41898282+github-actions[bot]@users.noreply.github.com> --- test/rulesets/Base/arraymath.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 7c76aa17a..0a1416444 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -215,8 +215,8 @@ @gpu test_frule(+, randn(2), randn(2), randn(2)) # rev @gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) - @gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) - test_rrule(+, randn(3,3), Diagonal(randn(3)), randn(3,3,1)) - test_rrule(+, randn(3,3), Diagonal(randn(3)), Symmetric(randn(3,3))) + @gpu test_rrule(+, randn(3), randn(3, 1), randn(3, 1, 1)) + test_rrule(+, randn(3, 3), Diagonal(randn(3)), randn(3, 3, 1)) + test_rrule(+, randn(3, 3), Diagonal(randn(3)), Symmetric(randn(3, 3))) end end