diff --git a/Project.toml b/Project.toml index 4824861e8..d94ce7a3d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "1.52.0" +version = "1.52.1" [deps] Adapt = "79e6a3ab-5dfb-504d-930d-738a2a938a0e" diff --git a/src/rulesets/Base/arraymath.jl b/src/rulesets/Base/arraymath.jl index 3a8b09984..7fbf46062 100644 --- a/src/rulesets/Base/arraymath.jl +++ b/src/rulesets/Base/arraymath.jl @@ -210,7 +210,13 @@ end # VERSION ##### `muladd` ##### -function frule((_, ΔA, ΔB, Δz), ::typeof(muladd), A, B, z) +function frule( + (_, ΔA, ΔB, Δz), + ::typeof(muladd), + A::AbstractVecOrMat{<:CommutativeMulNumber}, + B::AbstractVecOrMat{<:CommutativeMulNumber}, + z::Union{CommutativeMulNumber, AbstractVecOrMat{<:CommutativeMulNumber}} +) Ω = muladd(A, B, z) return Ω, ΔA * B .+ A * ΔB .+ Δz end diff --git a/test/rulesets/Base/arraymath.jl b/test/rulesets/Base/arraymath.jl index 847808c1f..2682f3b8a 100644 --- a/test/rulesets/Base/arraymath.jl +++ b/test/rulesets/Base/arraymath.jl @@ -85,30 +85,33 @@ @testset "muladd: $T" for T in (Float64, ComplexF64) @testset "add $(typeof(z))" for z in [rand(), rand(T, 3), rand(T, 3, 3), false] - @testset "forward mode" begin - @gpu test_frule(muladd, rand(T, 3, 5), rand(T, 5, 3), z) - end @testset "matrix * matrix" begin A = rand(T, 3, 3) B = rand(T, 3, 3) @gpu test_rrule(muladd, A, B, z) @gpu test_rrule(muladd, A', B, z) @gpu test_rrule(muladd, A , B', z) + @gpu test_frule(muladd, A, B, z) + @gpu test_frule(muladd, A', B, z) + @gpu test_frule(muladd, A , B', z) C = rand(T, 3, 5) D = rand(T, 5, 3) @gpu test_rrule(muladd, C, D, z) + @gpu test_frule(muladd, C, D, z) end if ndims(z) <= 1 @testset "matrix * vector" begin A, B = rand(T, 3, 3), rand(T, 3) test_rrule(muladd, A, B, z) test_rrule(muladd, A, B ⊢ rand(T, 3,1), z) + test_frule(muladd, A, B, z) end @testset "adjoint * matrix" begin At, B = rand(T, 3)', rand(T, 3, 3) test_rrule(muladd, At, B, z') test_rrule(muladd, At ⊢ rand(T,1,3), B, z') + test_frule(muladd, At, B, z') end end if ndims(z) == 0 @@ -116,6 +119,7 @@ A, B = rand(T, 3)', rand(T, 3) test_rrule(muladd, A, B, z) test_rrule(muladd, A ⊢ rand(T,1,3), B, z') + test_frule(muladd, A, B, z) end end if ndims(z) == 2 # other dims lead to e.g. muladd(ones(4), ones(1,4), 1) @@ -123,6 +127,7 @@ A, B = rand(T, 3), rand(T, 3)' test_rrule(muladd, A, B, z) test_rrule(muladd, A, B ⊢ rand(T,1,3), z) + test_frule(muladd, A, B, z) end end end