diff --git a/Project.toml b/Project.toml index ec664d3a5..1eea6f38d 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.51" +version = "0.7.52" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 853c97efc..2ec9e2c97 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -307,7 +307,7 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ hermA = Hermitian(A) ∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV) ∂hermA = eigen_rev!(hermA, λ, V, Δλ, ∂V) - ∂Atriu = _symherm_back(typeof(hermA), ∂hermA, hermA.uplo) + ∂Atriu = _symherm_back(typeof(hermA), ∂hermA, Symbol(hermA.uplo)) ∂A = ∂Atriu isa AbstractTriangular ? triu!(∂Atriu.data) : ∂Atriu elseif ΔV isa AbstractZero ∂K = Diagonal(Δλ) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 9471b3ba5..e3ea572b4 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -8,8 +8,8 @@ end function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) Ω = T(A, uplo) - function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, Ω.uplo), DoesNotExist()) + @inline function HermOrSym_pullback(ΔΩ) + return (NO_FIELDS, _symherm_back(typeof(Ω), ΔΩ, uplo), DoesNotExist()) end return Ω, HermOrSym_pullback end @@ -26,7 +26,7 @@ function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) TA = _symhermtype(A) T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)} uplo = A.uplo - ∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo) + ∂A = T∂A(_symherm_back(typeof(A), ΔΩ, Symbol(uplo)), uplo) return NO_FIELDS, ∂A end return TM(A), Matrix_pullback @@ -44,33 +44,46 @@ function _symherm_forward(A, ΔA) end # for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A -_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo) -function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo) - return _symmetric_back(ΔΩ, uplo) +@inline function _symherm_back(::Type{T}, ΔΩ, uplo::Symbol) where {T} + if T <: Symmetric + return _symmetric_back(ΔΩ, uplo) + elseif T <: Hermitian + if ΔΩ isa AbstractMatrix{<:Real} + return _symmetric_back(ΔΩ, uplo) + else + return _hermitian_back(ΔΩ, uplo) + end + end + error() end -_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo) -_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo) -function _symmetric_back(ΔΩ, uplo) +@inline function _symmetric_back(ΔΩ, uplo::Symbol) + if ΔΩ isa Diagonal + return ΔΩ + elseif ΔΩ isa LinearAlgebra.AbstractTriangular + if istriu(ΔΩ) + return Matrix(uplo === :U ? ΔΩ : transpose(ΔΩ)) + else + return Matrix(uplo === :U ? transpose(ΔΩ) : ΔΩ) + end + end L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ) - return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D + return uplo === :U ? U .+ transpose(L) - D : L .+ transpose(U) - D end -_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ -_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ)) -_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ) -function _hermitian_back(ΔΩ, uplo) - L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ)) - return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD -end -_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ) -function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo) - ∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ))) - return if istriu(ΔΩ) - return Matrix(uplo == 'U' ? ∂UL : ∂UL') - else - return Matrix(uplo == 'U' ? ∂UL' : ∂UL) +@inline function _hermitian_back(ΔΩ, uplo::Symbol) + if ΔΩ isa Diagonal + return real.(ΔΩ) + elseif ΔΩ isa LinearAlgebra.AbstractTriangular + ∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ))) + if istriu(ΔΩ) + return Matrix(uplo === :U ? ∂UL : ∂UL') + else + return Matrix(uplo === :U ? ∂UL' : ∂UL) + end end + L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ)) + return uplo === :U ? U .+ L' - rD : L .+ U' - rD end ##### diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 85d00e726..bea69302b 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -18,7 +18,7 @@ @testset "rrule" begin # on old versions of julia this combination doesn't infer but we don't care as # it infers fine on modern versions. - check_inferred = !(VERSION <= v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian) + check_inferred = !(VERSION < v"1.5" && T <: ComplexF64 && SymHerm <: Hermitian) x = randn(T, N, N) ∂x = randn(T, N, N) @@ -26,14 +26,26 @@ @testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular) rrule_test( SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing); - check_inferred = check_inferred + # type stability here critically relies on uplo being constant propagated, + # so we need to test this more carefully below + check_inferred=false, ) + if check_inferred + @inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo} + return rrule(SymHerm, x, uplo)[2](ΔΩ) + end)(SymHerm, x, MT(ΔΩ), Val(uplo)) + end end @testset "back(::Diagonal)" begin rrule_test( SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing); - check_inferred = check_inferred + check_inferred=false, ) + if check_inferred + @inferred (function (SymHerm, x, ΔΩ, ::Val{uplo}) where {uplo} + return rrule(SymHerm, x, uplo)[2](ΔΩ) + end)(SymHerm, x, Diagonal(ΔΩ), Val(uplo)) + end end end end