From 99c58a9139939eb04d17dd88926d5c0c84570d9a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 28 Feb 2021 16:28:41 -0800 Subject: [PATCH] Ensure pullback of exp works for immutable arrays (#381) * Ensure exp cotangent is mutable * Increment version number * Use convert and inplaceable trait --- Project.toml | 2 +- src/rulesets/LinearAlgebra/matfun.jl | 5 ++++- 2 files changed, 5 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 1eea6f38d..fa368706c 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.52" +version = "0.7.53" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/matfun.jl b/src/rulesets/LinearAlgebra/matfun.jl index 2b909ac89..28402e5b5 100644 --- a/src/rulesets/LinearAlgebra/matfun.jl +++ b/src/rulesets/LinearAlgebra/matfun.jl @@ -129,7 +129,10 @@ function rrule(::typeof(exp), A0::StridedMatrix{<:BlasFloat}) A = copy(A0) X, intermediates = _matfun!(exp, A) function exp_pullback(ΔX) - ∂A = _matfun_frechet_adjoint!(exp, ΔX, A, X, intermediates) + # Ensures ∂X is mutable. The outer `adjoint` is unwrapped without copy by + # the default _matfun_frechet_adjoint! + ∂X = ChainRulesCore.is_inplaceable_destination(ΔX) ? ΔX : convert(Matrix, ΔX')' + ∂A = _matfun_frechet_adjoint!(exp, ∂X, A, X, intermediates) return NO_FIELDS, ∂A end return X, exp_pullback