From ce0a458f5e501595f60bba634e97dacb839656c2 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 7 Jan 2021 14:52:26 -0800 Subject: [PATCH 1/2] Release type constraints to take any Composite --- src/rulesets/LinearAlgebra/factorization.jl | 2 +- src/rulesets/LinearAlgebra/symmetric.jl | 4 ++-- 2 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 9b081d9e7..21bcfbaf7 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -101,7 +101,7 @@ end function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}} F = eigen(A; kwargs...) - function eigen_pullback(ΔF::Composite{<:Eigen}) + function eigen_pullback(ΔF::Composite) λ, V = F.values, F.vectors Δλ, ΔV = ΔF.values, ΔF.vectors ΔV isa AbstractZero && Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 6e92fd6a6..c8d65885a 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -115,7 +115,7 @@ function rrule( kwargs..., ) F = eigen(A; kwargs...) - function eigen_pullback(ΔF::Composite{<:Eigen}) + function eigen_pullback(ΔF::Composite) λ, U = F.values, F.vectors Δλ, ΔU = ΔF.values, ΔF.vectors ΔU = ΔU isa AbstractZero ? ΔU : copy(ΔU) @@ -227,7 +227,7 @@ end # is supported by reverse-mode AD packages function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) F = svd(A) - function svd_pullback(ΔF::Composite{<:SVD}) + function svd_pullback(ΔF::Composite) U, V = F.U, F.V c = _svd_eigvals_sign!(similar(F.S), U, V) λ = F.S .* c From 8c390205ffe85261ce6e0cbd9a985c3be4918e51 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 7 Jan 2021 14:52:45 -0800 Subject: [PATCH 2/2] Increment version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index b036ce122..ffb057613 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.44" +version = "0.7.45" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"