diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 078bb602a..fa0b34d53 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -434,6 +434,16 @@ function rrule(::typeof(-), x::AbstractArray) return -x, negation_pullback end +##### +##### Subtraction +##### + +frule((_, Δx, Δy), ::typeof(-), x::AbstractArray, y::AbstractArray) = x-y, Δx-Δy + +function rrule(::typeof(-), x::AbstractArray, y::AbstractArray) + subtract_pullback(dy) = (NoTangent(), dy, -dy) + return x-y, subtract_pullback +end ##### ##### Addition (Multiarg `+`) diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 2682f3b8a..2b4970b56 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -217,4 +217,12 @@ @gpu test_rrule(+, randn(4, 4), randn(4, 4), randn(4, 4)) @gpu test_rrule(+, randn(3), randn(3,1), randn(3,1,1)) end + + @testset "subtraction" begin + # fwd + @gpu test_frule(-, randn(2), randn(2)) + # rev + @gpu test_rrule(-, randn(4, 4), randn(4, 4)) + @gpu test_rrule(-, randn(3), randn(3,1)) + end end