From b09660daa8c7455d682367dd7221a7316944c547 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 5 Dec 2020 16:30:13 -0800 Subject: [PATCH 01/62] Move symmetric rules to own file --- src/ChainRules.jl | 1 + src/rulesets/LinearAlgebra/structured.jl | 80 ------------------------ src/rulesets/LinearAlgebra/symmetric.jl | 79 +++++++++++++++++++++++ 3 files changed, 80 insertions(+), 80 deletions(-) create mode 100644 src/rulesets/LinearAlgebra/symmetric.jl diff --git a/src/ChainRules.jl b/src/ChainRules.jl index 7c795017c..0c07ad77c 100644 --- a/src/ChainRules.jl +++ b/src/ChainRules.jl @@ -44,6 +44,7 @@ include("rulesets/LinearAlgebra/utils.jl") include("rulesets/LinearAlgebra/blas.jl") include("rulesets/LinearAlgebra/dense.jl") include("rulesets/LinearAlgebra/structured.jl") +include("rulesets/LinearAlgebra/symmetric.jl") include("rulesets/LinearAlgebra/factorization.jl") include("rulesets/Random/random.jl") diff --git a/src/rulesets/LinearAlgebra/structured.jl b/src/rulesets/LinearAlgebra/structured.jl index c7a43fd7a..b80673f65 100644 --- a/src/rulesets/LinearAlgebra/structured.jl +++ b/src/rulesets/LinearAlgebra/structured.jl @@ -86,86 +86,6 @@ function rrule(::typeof(*), D::Diagonal{<:Real}, V::AbstractVector{<:Real}) return D * V, times_pullback end -##### -##### `Symmetric`/`Hermitian` -##### - -function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) - return T(A, uplo), T(ΔA, uplo) -end - -function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) - Ω = T(A, uplo) - function HermOrSym_pullback(ΔΩ) - return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist()) - end - return Ω, HermOrSym_pullback -end - -function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) - return TM(A), TM(_symherm_forward(A, ΔA)) -end -function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym) - return Array(A), Array(_symherm_forward(A, ΔA)) -end - -function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) - function Matrix_pullback(ΔΩ) - TA = _symhermtype(A) - T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)} - uplo = A.uplo - ∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo) - return NO_FIELDS, ∂A - end - return TM(A), Matrix_pullback -end -rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A) - -# Get type (Symmetric or Hermitian) from type or matrix -_symhermtype(::Type{<:Symmetric}) = Symmetric -_symhermtype(::Type{<:Hermitian}) = Hermitian -_symhermtype(A) = _symhermtype(typeof(A)) - -# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω -function _symherm_forward(A, ΔA) - TA = _symhermtype(A) - return if ΔA isa TA - ΔA - else - TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo) - end -end - -# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A -_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo) -function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo) - return _symmetric_back(ΔΩ, uplo) -end -_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo) -_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo) - -function _symmetric_back(ΔΩ, uplo) - L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ) - return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D -end -_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ -_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ)) -_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ) - -function _hermitian_back(ΔΩ, uplo) - L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ)) - return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD -end -_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ) -function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo) - ∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ))) - return if istriu(ΔΩ) - return Matrix(uplo == 'U' ? ∂UL : ∂UL') - else - return Matrix(uplo == 'U' ? ∂UL' : ∂UL) - end -end - ##### ##### `Adjoint` ##### diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl new file mode 100644 index 000000000..44c1cb740 --- /dev/null +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -0,0 +1,79 @@ +##### +##### `Symmetric`/`Hermitian` +##### + +function frule((_, ΔA, _), T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) + return T(A, uplo), T(ΔA, uplo) +end + +function rrule(T::Type{<:LinearAlgebra.HermOrSym}, A::AbstractMatrix, uplo) + Ω = T(A, uplo) + function HermOrSym_pullback(ΔΩ) + return (NO_FIELDS, _symherm_back(T, ΔΩ, Ω.uplo), DoesNotExist()) + end + return Ω, HermOrSym_pullback +end + +function frule((_, ΔA), TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) + return TM(A), TM(_symherm_forward(A, ΔA)) +end +function frule((_, ΔA), ::Type{Array}, A::LinearAlgebra.HermOrSym) + return Array(A), Array(_symherm_forward(A, ΔA)) +end + +function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) + function Matrix_pullback(ΔΩ) + TA = _symhermtype(A) + T∂A = TA{eltype(ΔΩ),typeof(ΔΩ)} + uplo = A.uplo + ∂A = T∂A(_symherm_back(A, ΔΩ, uplo), uplo) + return NO_FIELDS, ∂A + end + return TM(A), Matrix_pullback +end +rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A) + +# Get type (Symmetric or Hermitian) from type or matrix +_symhermtype(::Type{<:Symmetric}) = Symmetric +_symhermtype(::Type{<:Hermitian}) = Hermitian +_symhermtype(A) = _symhermtype(typeof(A)) + +# for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω +function _symherm_forward(A, ΔA) + TA = _symhermtype(A) + return if ΔA isa TA + ΔA + else + TA{eltype(ΔA),typeof(ΔA)}(ΔA, A.uplo) + end +end + +# for Ω = HermOrSym(A, uplo), pull back ΔΩ to get ∂A +_symherm_back(::Type{<:Symmetric}, ΔΩ, uplo) = _symmetric_back(ΔΩ, uplo) +function _symherm_back(::Type{<:Hermitian}, ΔΩ::AbstractMatrix{<:Real}, uplo) + return _symmetric_back(ΔΩ, uplo) +end +_symherm_back(::Type{<:Hermitian}, ΔΩ, uplo) = _hermitian_back(ΔΩ, uplo) +_symherm_back(Ω, ΔΩ, uplo) = _symherm_back(typeof(Ω), ΔΩ, uplo) + +function _symmetric_back(ΔΩ, uplo) + L, U, D = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), Diagonal(ΔΩ) + return uplo == 'U' ? U .+ transpose(L) - D : L .+ transpose(U) - D +end +_symmetric_back(ΔΩ::Diagonal, uplo) = ΔΩ +_symmetric_back(ΔΩ::UpperTriangular, uplo) = Matrix(uplo == 'U' ? ΔΩ : transpose(ΔΩ)) +_symmetric_back(ΔΩ::LowerTriangular, uplo) = Matrix(uplo == 'U' ? transpose(ΔΩ) : ΔΩ) + +function _hermitian_back(ΔΩ, uplo) + L, U, rD = LowerTriangular(ΔΩ), UpperTriangular(ΔΩ), real.(Diagonal(ΔΩ)) + return uplo == 'U' ? U .+ L' - rD : L .+ U' - rD +end +_hermitian_back(ΔΩ::Diagonal, uplo) = real.(ΔΩ) +function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo) + ∂UL = ΔΩ .- Diagonal(_extract_imag.(diag(ΔΩ))) + return if istriu(ΔΩ) + return Matrix(uplo == 'U' ? ∂UL : ∂UL') + else + return Matrix(uplo == 'U' ? ∂UL' : ∂UL) + end +end From 76bfac37550883fa4794576578997b67038e6c4e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 5 Dec 2020 16:30:28 -0800 Subject: [PATCH 02/62] Move symmetric tests to own file --- test/rulesets/LinearAlgebra/structured.jl | 42 ---------------------- test/rulesets/LinearAlgebra/symmetric.jl | 44 +++++++++++++++++++++++ test/runtests.jl | 1 + 3 files changed, 45 insertions(+), 42 deletions(-) create mode 100644 test/rulesets/LinearAlgebra/symmetric.jl diff --git a/test/rulesets/LinearAlgebra/structured.jl b/test/rulesets/LinearAlgebra/structured.jl index 559ac0094..1dc04e487 100644 --- a/test/rulesets/LinearAlgebra/structured.jl +++ b/test/rulesets/LinearAlgebra/structured.jl @@ -104,48 +104,6 @@ end end end - @testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for - SymHerm in (Symmetric, Hermitian), - T in (Float64, ComplexF64), - uplo in (:U, :L) - - N = 3 - @testset "frule" begin - x = randn(T, N, N) - Δx = randn(T, N, N) - # can't use frule_test here because it doesn't yet ignore nothing tangents - Ω = SymHerm(x, uplo) - Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo) - @test Ω_ad == Ω - ∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx)) - @test ∂Ω_ad ≈ ∂Ω_fd - end - @testset "rrule" begin - x = randn(T, N, N) - ∂x = randn(T, N, N) - ΔΩ = randn(T, N, N) - @testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular) - rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing)) - end - @testset "back(::Diagonal)" begin - rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing)) - end - end - end - @testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array), - SymHerm in (Symmetric, Hermitian), - T in (Float64, ComplexF64), - uplo in (:U, :L) - - N = 3 - x = SymHerm(randn(T, N, N), uplo) - Δx = randn(T, N, N) - ∂x = SymHerm(randn(T, N, N), uplo) - ΔΩ = f(SymHerm(randn(T, N, N), uplo)) - frule_test(f, (x, Δx)) - frule_test(f, (x, SymHerm(Δx, uplo))) - rrule_test(f, ΔΩ, (x, ∂x)) - end @testset "$f" for f in (Adjoint, adjoint, Transpose, transpose) n = 5 m = 3 diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl new file mode 100644 index 000000000..f76c28836 --- /dev/null +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -0,0 +1,44 @@ +@testset "Symmetric/Hermitian rules" begin + @testset "$(SymHerm)(::AbstractMatrix{$T}, :$(uplo))" for + SymHerm in (Symmetric, Hermitian), + T in (Float64, ComplexF64), + uplo in (:U, :L) + + N = 3 + @testset "frule" begin + x = randn(T, N, N) + Δx = randn(T, N, N) + # can't use frule_test here because it doesn't yet ignore nothing tangents + Ω = SymHerm(x, uplo) + Ω_ad, ∂Ω_ad = frule((Zero(), Δx, Zero()), SymHerm, x, uplo) + @test Ω_ad == Ω + ∂Ω_fd = jvp(_fdm, z -> SymHerm(z, uplo), (x, Δx)) + @test ∂Ω_ad ≈ ∂Ω_fd + end + @testset "rrule" begin + x = randn(T, N, N) + ∂x = randn(T, N, N) + ΔΩ = randn(T, N, N) + @testset "back(::$MT)" for MT in (Matrix, LowerTriangular, UpperTriangular) + rrule_test(SymHerm, MT(ΔΩ), (x, ∂x), (uplo, nothing)) + end + @testset "back(::Diagonal)" begin + rrule_test(SymHerm, Diagonal(ΔΩ), (x, Diagonal(∂x)), (uplo, nothing)) + end + end + end + @testset "$(f)(::$(SymHerm){$T}) with uplo=:$uplo" for f in (Matrix, Array), + SymHerm in (Symmetric, Hermitian), + T in (Float64, ComplexF64), + uplo in (:U, :L) + + N = 3 + x = SymHerm(randn(T, N, N), uplo) + Δx = randn(T, N, N) + ∂x = SymHerm(randn(T, N, N), uplo) + ΔΩ = f(SymHerm(randn(T, N, N), uplo)) + frule_test(f, (x, Δx)) + frule_test(f, (x, SymHerm(Δx, uplo))) + rrule_test(f, ΔΩ, (x, ∂x)) + end +end diff --git a/test/runtests.jl b/test/runtests.jl index 2daacf89f..99bb6b251 100644 --- a/test/runtests.jl +++ b/test/runtests.jl @@ -42,6 +42,7 @@ println("Testing ChainRules.jl") @testset "LinearAlgebra" begin include_test("rulesets/LinearAlgebra/dense.jl") include_test("rulesets/LinearAlgebra/structured.jl") + include_test("rulesets/LinearAlgebra/symmetric.jl") include_test("rulesets/LinearAlgebra/factorization.jl") include_test("rulesets/LinearAlgebra/blas.jl") end From 539e9d4f416fb0f046235cd14fcaa0bfb591500a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 5 Dec 2020 17:34:05 -0800 Subject: [PATCH 03/62] Adapt eigen and eigvals rules from #321 --- src/rulesets/LinearAlgebra/symmetric.jl | 94 +++++++++++++++++++++++++ 1 file changed, 94 insertions(+) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 44c1cb740..cc092b35d 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -77,3 +77,97 @@ function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo) return Matrix(uplo == 'U' ? ∂UL' : ∂UL) end end + +##### +##### `eigen` +##### + +function frule((_, ΔA), ::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothing) + F = eigen(A; sortby=sortby) + ΔA isa AbstractZero && return F, ΔA + λ, U = F.values, F.vectors + tmp = U' * ΔA + ∂K = tmp * U + ∂Kdiag = @view ∂K[diagind(∂K)] + ∂λ = real.(∂Kdiag) + ∂K ./= λ' .- λ + fill!(∂Kdiag, 0) + ∂U = mul!(tmp, U, ∂K) + _eigen_norm_phase_fwd!(∂U, A, U) + ∂F = Composite{typeof(F)}(values = ∂λ, vectors = ∂U) + return F, ∂F +end + +function rrule(::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothing) + F = eigen(A; sortby=sortby) + function eigen_pullback(ΔF::Composite{<:Eigen}) + λ, U = F.values, F.vectors + Δλ, ΔU = ΔF.values, ΔF.vectors + if ΔU isa AbstractZero + Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔU) + ∂K = Diagonal(Δλ) + ∂A = U * ∂K * U' + else + ∂U = copyto!(similar(ΔU), ΔU) + _eigen_norm_phase_rev!(∂U, A, U) + ∂K = U' * ∂U + ∂K ./= λ' .- λ + ∂K[diagind(∂K)] = Δλ + ∂A = mul!(∂K, U * ∂K, U') + end + return NO_FIELDS, ∂A + end + eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + return F, eigen_pullback +end + +_eigen_norm_phase_fwd!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V +function _eigen_norm_phase_fwd!(∂V, A::Hermitian, V) + k = A.uplo === 'U' ? size(A, 1) : 1 + @inbounds for i in axes(V, 2) + vᵢ = @view V[:, i] + vₖᵢ, ∂vₖᵢ = real(vᵢ[k]), ∂V[k, i] + ∂vᵢ .-= vᵢ .* (imag(∂vₖᵢ) / ifelse(iszero(vₖᵢ), one(vₖᵢ), vₖᵢ)) + end + return ∂V +end + +_eigen_norm_phase_rev!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V +function _eigen_norm_phase_rev!(∂V, A::Hermitian, V) + k = A.uplo === 'U' ? size(A, 1) : 1 + @inbounds for i in axes(V, 2) + vᵢ, ∂vᵢ = @views V[:, i], ∂V[:, i] + vₖᵢ = real(vᵢ[k]) + ∂cᵢ = dot(vᵢ, ∂vᵢ) + ∂vᵢ[k] -= im * (imag(∂cᵢ) / ifelse(iszero(vₖᵢ), one(vₖᵢ), vₖᵢ)) + end + return ∂V +end + +##### +##### `eigvals` +##### + +function frule((_, ΔA), ::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm) + ΔA isa AbstractZero && return eigvals(A), ΔA + F = eigen(A) + λ, U = F.values, F.vectors + tmp = ΔA * U + ∂λ = similar(λ) + @inbounds for i in eachindex(λ) + ∂λ[i] = real(dot(U[:, i], tmp[:, i])) + end + return λ, ∂λ +end + +function rrule(::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm) + F = eigen(A) + λ = F.values + function eigvals_pullback(Δλ) + U = F.vectors + ∂A = U * Diagonal(Δλ) * U' + return NO_FIELDS, ∂A + end + eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) + return λ, eigvals_pullback +end From 7071e1e4a3ff458e478b93f8fd56bd9c383508b1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 5 Dec 2020 18:34:38 -0800 Subject: [PATCH 04/62] Don't allocate --- src/rulesets/LinearAlgebra/symmetric.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index cc092b35d..7ec5318a2 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -153,9 +153,10 @@ function frule((_, ΔA), ::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexH F = eigen(A) λ, U = F.values, F.vectors tmp = ΔA * U + # diag(U' * tmp) without computing matrix product ∂λ = similar(λ) @inbounds for i in eachindex(λ) - ∂λ[i] = real(dot(U[:, i], tmp[:, i])) + ∂λ[i] = @views real(dot(U[:, i], tmp[:, i])) end return λ, ∂λ end From 1090482429d7f3e5a504a22c4105c5eed82df01a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 5 Dec 2020 23:18:12 -0800 Subject: [PATCH 05/62] Implement mutating forms for frule --- src/rulesets/LinearAlgebra/symmetric.jl | 26 +++++++++++++++++-------- 1 file changed, 18 insertions(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 7ec5318a2..516134603 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -79,15 +79,20 @@ function _hermitian_back(ΔΩ::LinearAlgebra.AbstractTriangular, uplo) end ##### -##### `eigen` +##### `eigen!`/`eigen` ##### -function frule((_, ΔA), ::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothing) - F = eigen(A; sortby=sortby) +function frule( + (_, ΔA), + ::typeof(eigen!), + A::LinearAlgebra.RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}; + sortby::Union{Function,Nothing}=nothing, +) + F = eigen!(A; sortby=sortby) ΔA isa AbstractZero && return F, ΔA λ, U = F.values, F.vectors tmp = U' * ΔA - ∂K = tmp * U + ∂K = mul!(ΔA.data, tmp, U) ∂Kdiag = @view ∂K[diagind(∂K)] ∂λ = real.(∂Kdiag) ∂K ./= λ' .- λ @@ -145,12 +150,17 @@ function _eigen_norm_phase_rev!(∂V, A::Hermitian, V) end ##### -##### `eigvals` +##### `eigvals!`/`eigvals` ##### -function frule((_, ΔA), ::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm) - ΔA isa AbstractZero && return eigvals(A), ΔA - F = eigen(A) +function frule( + (_, ΔA), + ::typeof(eigvals!), + A::LinearAlgebra.RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}; + sortby::Union{Function,Nothing}=nothing, +) + ΔA isa AbstractZero && return eigvals!(A; sortby=sortby), ΔA + F = eigen!(A; sortby=sortby) λ, U = F.values, F.vectors tmp = ΔA * U # diag(U' * tmp) without computing matrix product From bbd86bbcfd85abe4cf0a848856b5ef0a24fbdbe1 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 5 Dec 2020 23:18:25 -0800 Subject: [PATCH 06/62] Add sortby keyword --- src/rulesets/LinearAlgebra/symmetric.jl | 14 +++++++++++--- 1 file changed, 11 insertions(+), 3 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 516134603..6bd4f1241 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -103,7 +103,11 @@ function frule( return F, ∂F end -function rrule(::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm; sortby::Union{Function,Nothing}=nothing) +function rrule( + ::typeof(eigen), + A::LinearAlgebra.RealHermSymComplexHerm; + sortby::Union{Function,Nothing}=nothing, +) F = eigen(A; sortby=sortby) function eigen_pullback(ΔF::Composite{<:Eigen}) λ, U = F.values, F.vectors @@ -171,8 +175,12 @@ function frule( return λ, ∂λ end -function rrule(::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm) - F = eigen(A) +function rrule( + ::typeof(eigvals), + A::LinearAlgebra.RealHermSymComplexHerm; + sortby::Union{Function,Nothing}=nothing, +) + F = eigen(A; sortby=sortby) λ = F.values function eigvals_pullback(Δλ) U = F.vectors From 1d7dc7ccd54fc48a30c81a5bcf0e7db618a45c60 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 5 Dec 2020 23:18:38 -0800 Subject: [PATCH 07/62] Use fewer indices --- src/rulesets/LinearAlgebra/symmetric.jl | 14 +++++++------- 1 file changed, 7 insertions(+), 7 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 6bd4f1241..c0a0ee060 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -134,9 +134,9 @@ _eigen_norm_phase_fwd!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V function _eigen_norm_phase_fwd!(∂V, A::Hermitian, V) k = A.uplo === 'U' ? size(A, 1) : 1 @inbounds for i in axes(V, 2) - vᵢ = @view V[:, i] - vₖᵢ, ∂vₖᵢ = real(vᵢ[k]), ∂V[k, i] - ∂vᵢ .-= vᵢ .* (imag(∂vₖᵢ) / ifelse(iszero(vₖᵢ), one(vₖᵢ), vₖᵢ)) + v = @view V[:, i] + vₖ, ∂vₖ = real(v[k]), ∂V[k, i] + ∂v .-= v .* (imag(∂vₖ) / ifelse(iszero(vₖ), one(vₖ), vₖ)) end return ∂V end @@ -145,10 +145,10 @@ _eigen_norm_phase_rev!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V function _eigen_norm_phase_rev!(∂V, A::Hermitian, V) k = A.uplo === 'U' ? size(A, 1) : 1 @inbounds for i in axes(V, 2) - vᵢ, ∂vᵢ = @views V[:, i], ∂V[:, i] - vₖᵢ = real(vᵢ[k]) - ∂cᵢ = dot(vᵢ, ∂vᵢ) - ∂vᵢ[k] -= im * (imag(∂cᵢ) / ifelse(iszero(vₖᵢ), one(vₖᵢ), vₖᵢ)) + v, ∂v = @views V[:, i], ∂V[:, i] + vₖ = real(v[k]) + ∂c = dot(v, ∂v) + ∂v[k] -= im * (imag(∂c) / ifelse(iszero(vₖ), one(vₖ), vₖ)) end return ∂V end From 33e10da59189cbf31ef12fcb66b9aaae0d3b0d94 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 6 Dec 2020 02:18:15 -0800 Subject: [PATCH 08/62] Correctly reference BlasReal --- src/rulesets/LinearAlgebra/symmetric.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index c0a0ee060..3fa9ecc2f 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -85,7 +85,7 @@ end function frule( (_, ΔA), ::typeof(eigen!), - A::LinearAlgebra.RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}; + A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; sortby::Union{Function,Nothing}=nothing, ) F = eigen!(A; sortby=sortby) @@ -160,7 +160,7 @@ end function frule( (_, ΔA), ::typeof(eigvals!), - A::LinearAlgebra.RealHermSymComplexHerm{<:BlasReal,<:StridedMatrix}; + A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; sortby::Union{Function,Nothing}=nothing, ) ΔA isa AbstractZero && return eigvals!(A; sortby=sortby), ΔA From bbfced7c4f8f9f0061c5f81a8c4754fb0ab669b8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 6 Dec 2020 02:18:58 -0800 Subject: [PATCH 09/62] Move eigen pullback to external function --- src/rulesets/LinearAlgebra/symmetric.jl | 29 +++++++++++++++---------- 1 file changed, 17 insertions(+), 12 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 3fa9ecc2f..6cecb4040 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -112,24 +112,29 @@ function rrule( function eigen_pullback(ΔF::Composite{<:Eigen}) λ, U = F.values, F.vectors Δλ, ΔU = ΔF.values, ΔF.vectors - if ΔU isa AbstractZero - Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔU) - ∂K = Diagonal(Δλ) - ∂A = U * ∂K * U' - else - ∂U = copyto!(similar(ΔU), ΔU) - _eigen_norm_phase_rev!(∂U, A, U) - ∂K = U' * ∂U - ∂K ./= λ' .- λ - ∂K[diagind(∂K)] = Δλ - ∂A = mul!(∂K, U * ∂K, U') - end + ∂A = eigen_rev(A, λ, U, Δλ, ΔU) return NO_FIELDS, ∂A end eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) return F, eigen_pullback end +function eigen_rev(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) + if ∂U isa AbstractZero + ∂λ isa AbstractZero && return (NO_FIELDS, ∂λ + ∂U) + ∂K = Diagonal(∂λ) + ∂A = U * ∂K * U' + else + ∂U = copyto!(similar(∂U), ∂U) + _eigen_norm_phase_rev!(∂U, A, U) + ∂K = U' * ∂U + ∂K ./= λ' .- λ + ∂K[diagind(∂K)] = ∂λ + ∂A = mul!(∂K, U * ∂K, U') + end + return ∂A +end + _eigen_norm_phase_fwd!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V function _eigen_norm_phase_fwd!(∂V, A::Hermitian, V) k = A.uplo === 'U' ? size(A, 1) : 1 From c3d6cfe9e0f369e265b95453c270c3455c71ccbc Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 6 Dec 2020 02:19:10 -0800 Subject: [PATCH 10/62] Add hermitian svd rrule --- src/rulesets/LinearAlgebra/symmetric.jl | 27 +++++++++++++++++++++++++ 1 file changed, 27 insertions(+) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 6cecb4040..37b59492f 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -195,3 +195,30 @@ function rrule( eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) return λ, eigvals_pullback end + +##### +##### `svd` +##### + +function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm) + F = svd(A) + function svd_pullback(ΔF::Composite{<:SVD}) + U, Vt = F.U, F.Vt + ∂S = ΔF.S + # recreate sign difference between U and Vt + n = size(U, 1) + c = similar(F.S, Int) + @inbounds broadcast!(c, eachindex(c)) do i + u = @views U[:, i] + # find element not close to zero + # at least one element has abs2 ≥ 1/n > 1/(n + 1) + k = findfirst(x -> (n + 1) * abs2(x) ≥ 1, u) + return sign(real(u[k]) * real(Vt[k, i])) + end + ∂U = ΔF.U .+ (ΔF.Vt .+ ΔF.V') .* c' + ∂A = eigen_rev(A, F.S .* c, U, ∂S .* c, ∂U) + return NO_FIELDS, ∂A + end + svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + return F, svd_pullback +end From a1fa32fc91e96bef13f590e7bbf8e5b47e9ffd27 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 6 Dec 2020 02:34:02 -0800 Subject: [PATCH 11/62] Hermitrize in pullback --- src/rulesets/LinearAlgebra/symmetric.jl | 17 +++++++++++++++-- 1 file changed, 15 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 37b59492f..06c55eeb0 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -132,7 +132,7 @@ function eigen_rev(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) ∂K[diagind(∂K)] = ∂λ ∂A = mul!(∂K, U * ∂K, U') end - return ∂A + return _hermitrize!(∂A, A) end _eigen_norm_phase_fwd!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V @@ -189,7 +189,7 @@ function rrule( λ = F.values function eigvals_pullback(Δλ) U = F.vectors - ∂A = U * Diagonal(Δλ) * U' + ∂A = _hermitrize!(U * Diagonal(Δλ) * U', A) return NO_FIELDS, ∂A end eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) @@ -222,3 +222,16 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm) svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) return F, svd_pullback end +##### +##### utilities +##### + +# in-place hermitrize matrix, optionally wrapping like A +function _hermitrize!(A) + A .= (A .+ A') ./ 2 + return A +end +function _hermitrize!(∂A, A) + _hermitrize!(∂A) + return _symhermtype(A)(∂A, Symbol(A.uplo)) +end From 44ac18d0d5d7e459c31035f67f8cb08bba192315 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 6 Dec 2020 02:36:50 -0800 Subject: [PATCH 12/62] Separate out eigvals sign code --- src/rulesets/LinearAlgebra/symmetric.jl | 25 +++++++++++++++---------- 1 file changed, 15 insertions(+), 10 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 06c55eeb0..b22da506d 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -205,16 +205,7 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm) function svd_pullback(ΔF::Composite{<:SVD}) U, Vt = F.U, F.Vt ∂S = ΔF.S - # recreate sign difference between U and Vt - n = size(U, 1) - c = similar(F.S, Int) - @inbounds broadcast!(c, eachindex(c)) do i - u = @views U[:, i] - # find element not close to zero - # at least one element has abs2 ≥ 1/n > 1/(n + 1) - k = findfirst(x -> (n + 1) * abs2(x) ≥ 1, u) - return sign(real(u[k]) * real(Vt[k, i])) - end + c = _svd_eigvals_sign!(similar(F.S), U, Vt) ∂U = ΔF.U .+ (ΔF.Vt .+ ΔF.V') .* c' ∂A = eigen_rev(A, F.S .* c, U, ∂S .* c, ∂U) return NO_FIELDS, ∂A @@ -222,6 +213,20 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm) svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) return F, svd_pullback end + +# given singular vectors, compute sign of eigenvalues corresponding to singular values +function _svd_eigvals_sign!(c, U, Vt) + n = size(U, 1) + @inbounds broadcast!(c, eachindex(c)) do i + u = @views U[:, i] + # find element not close to zero + # at least one element has abs2 ≥ 1/n > 1/(n + 1) + k = findfirst(x -> (n + 1) * abs2(x) ≥ 1, u) + return sign(real(u[k]) * real(Vt[k, i])) + end + return c +end + ##### ##### utilities ##### From a211da1a0a66fea5a24413a5b12ef5c71a60daf6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sun, 6 Dec 2020 03:40:54 -0800 Subject: [PATCH 13/62] Add svdvals rrule --- src/rulesets/LinearAlgebra/symmetric.jl | 13 +++++++++++++ 1 file changed, 13 insertions(+) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index b22da506d..fa585c634 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -227,6 +227,19 @@ function _svd_eigvals_sign!(c, U, Vt) return c end +function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm) + # sorting doesn't affect the eigvals pullback, and it simplifies this rrule + λ, back = rrule(eigvals, A; sortby = x -> -abs2(x)) + S = abs.(λ) + function svdvals_pullback(ΔS) + ∂λ = ΔS .* S ./ ifelse.(iszero.(λ), one.(λ), λ) + ∂A = back(∂λ) + return NO_FIELDS, ∂A + end + svdvals_pullback(ΔS::AbstractZero) = (NO_FIELDS, ΔS) + return S, svdvals_pullback +end + ##### ##### utilities ##### From 3a8ffc22ee1465589bef11f4872588aec4114aa0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 14:19:05 -0800 Subject: [PATCH 14/62] Rearrange functions --- src/rulesets/LinearAlgebra/symmetric.jl | 10 +++++----- 1 file changed, 5 insertions(+), 5 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index fa585c634..f5ca7b297 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -33,11 +33,6 @@ function rrule(TM::Type{<:Matrix}, A::LinearAlgebra.HermOrSym) end rrule(::Type{Array}, A::LinearAlgebra.HermOrSym) = rrule(Matrix, A) -# Get type (Symmetric or Hermitian) from type or matrix -_symhermtype(::Type{<:Symmetric}) = Symmetric -_symhermtype(::Type{<:Hermitian}) = Hermitian -_symhermtype(A) = _symhermtype(typeof(A)) - # for Ω = Matrix(A::HermOrSym), push forward ΔA to get ∂Ω function _symherm_forward(A, ΔA) TA = _symhermtype(A) @@ -244,6 +239,11 @@ end ##### utilities ##### +# Get type (Symmetric or Hermitian) from type or matrix +_symhermtype(::Type{<:Symmetric}) = Symmetric +_symhermtype(::Type{<:Hermitian}) = Hermitian +_symhermtype(A) = _symhermtype(typeof(A)) + # in-place hermitrize matrix, optionally wrapping like A function _hermitrize!(A) A .= (A .+ A') ./ 2 From de4b3a10a35b2f1f8cf764724c746bebdef5ee81 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 17:43:16 -0800 Subject: [PATCH 15/62] Realify eigenvalue cotangents --- src/rulesets/LinearAlgebra/symmetric.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index f5ca7b297..cd7bda453 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -124,7 +124,7 @@ function eigen_rev(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) _eigen_norm_phase_rev!(∂U, A, U) ∂K = U' * ∂U ∂K ./= λ' .- λ - ∂K[diagind(∂K)] = ∂λ + ∂K[diagind(∂K)] .= real.(∂λ) ∂A = mul!(∂K, U * ∂K, U') end return _hermitrize!(∂A, A) @@ -184,7 +184,8 @@ function rrule( λ = F.values function eigvals_pullback(Δλ) U = F.vectors - ∂A = _hermitrize!(U * Diagonal(Δλ) * U', A) + ∂A = similar(A) + mul!(∂A.data, U .* real.(Δλ'), U') return NO_FIELDS, ∂A end eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) From 1004e8889d2ed5470a5b2121325bcdce6adeea32 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 19:36:45 -0800 Subject: [PATCH 16/62] Restrict to StridedMatrixes --- src/rulesets/LinearAlgebra/symmetric.jl | 8 ++++---- 1 file changed, 4 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index cd7bda453..d9e0b1990 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -100,7 +100,7 @@ end function rrule( ::typeof(eigen), - A::LinearAlgebra.RealHermSymComplexHerm; + A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; sortby::Union{Function,Nothing}=nothing, ) F = eigen(A; sortby=sortby) @@ -177,7 +177,7 @@ end function rrule( ::typeof(eigvals), - A::LinearAlgebra.RealHermSymComplexHerm; + A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; sortby::Union{Function,Nothing}=nothing, ) F = eigen(A; sortby=sortby) @@ -196,7 +196,7 @@ end ##### `svd` ##### -function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm) +function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) F = svd(A) function svd_pullback(ΔF::Composite{<:SVD}) U, Vt = F.U, F.Vt @@ -223,7 +223,7 @@ function _svd_eigvals_sign!(c, U, Vt) return c end -function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm) +function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) # sorting doesn't affect the eigvals pullback, and it simplifies this rrule λ, back = rrule(eigvals, A; sortby = x -> -abs2(x)) S = abs.(λ) From 202700132dc61613669087bfe22addd80cd10e99 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 19:38:42 -0800 Subject: [PATCH 17/62] Reduce allocations and unnecessary ops --- src/rulesets/LinearAlgebra/symmetric.jl | 20 ++++++++++++-------- 1 file changed, 12 insertions(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index d9e0b1990..a1f451cc0 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -115,19 +115,23 @@ function rrule( end function eigen_rev(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) + ∂λ isa AbstractZero && ∂λ isa AbstractZero && return (NO_FIELDS, ∂λ + ∂U) + ∂A = similar(A, eltype(U)) + tmp = similar(U) if ∂U isa AbstractZero - ∂λ isa AbstractZero && return (NO_FIELDS, ∂λ + ∂U) - ∂K = Diagonal(∂λ) - ∂A = U * ∂K * U' + tmp .= U .* real.(∂λ) + mul!(∂A.data, U, tmp') else - ∂U = copyto!(similar(∂U), ∂U) + ∂U = copyto!(tmp, ∂U) _eigen_norm_phase_rev!(∂U, A, U) - ∂K = U' * ∂U + ∂K = mul!(∂A.data, U', ∂U) ∂K ./= λ' .- λ ∂K[diagind(∂K)] .= real.(∂λ) - ∂A = mul!(∂K, U * ∂K, U') + mul!(tmp, ∂K, U') + mul!(∂A.data, U, tmp) + _hermitrize!(∂A.data) end - return _hermitrize!(∂A, A) + return ∂A end _eigen_norm_phase_fwd!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V @@ -185,7 +189,7 @@ function rrule( function eigvals_pullback(Δλ) U = F.vectors ∂A = similar(A) - mul!(∂A.data, U .* real.(Δλ'), U') + mul!(∂A.data, U, (U .* real.(Δλ))') return NO_FIELDS, ∂A end eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) From 429025cf71c9394d68d90289aaa05cf2bdbaea8b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 19:45:35 -0800 Subject: [PATCH 18/62] Avoid unnecessary allocation in svd --- src/rulesets/LinearAlgebra/symmetric.jl | 13 ++++++------- 1 file changed, 6 insertions(+), 7 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index a1f451cc0..2ba62b729 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -107,22 +107,21 @@ function rrule( function eigen_pullback(ΔF::Composite{<:Eigen}) λ, U = F.values, F.vectors Δλ, ΔU = ΔF.values, ΔF.vectors - ∂A = eigen_rev(A, λ, U, Δλ, ΔU) + ∂A = eigen_rev!(A, λ, U, Δλ, copy(ΔU)) return NO_FIELDS, ∂A end eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) return F, eigen_pullback end -function eigen_rev(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) +# ∂U is overwritten if not an `AbstractZero` +function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) ∂λ isa AbstractZero && ∂λ isa AbstractZero && return (NO_FIELDS, ∂λ + ∂U) ∂A = similar(A, eltype(U)) - tmp = similar(U) + tmp = ∂U if ∂U isa AbstractZero - tmp .= U .* real.(∂λ) - mul!(∂A.data, U, tmp') + mul!(∂A.data, U, (U .* real.(∂λ))') else - ∂U = copyto!(tmp, ∂U) _eigen_norm_phase_rev!(∂U, A, U) ∂K = mul!(∂A.data, U', ∂U) ∂K ./= λ' .- λ @@ -207,7 +206,7 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.Bla ∂S = ΔF.S c = _svd_eigvals_sign!(similar(F.S), U, Vt) ∂U = ΔF.U .+ (ΔF.Vt .+ ΔF.V') .* c' - ∂A = eigen_rev(A, F.S .* c, U, ∂S .* c, ∂U) + ∂A = eigen_rev!(A, F.S .* c, U, ∂S .* c, ∂U) return NO_FIELDS, ∂A end svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) From efb60705fc7fa23532dc1e094db8c4473fb05d7b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 21:02:32 -0800 Subject: [PATCH 19/62] Explicitly create eigvals pullback inputs --- src/rulesets/LinearAlgebra/symmetric.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 2ba62b729..f422a2442 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -203,10 +203,11 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.Bla F = svd(A) function svd_pullback(ΔF::Composite{<:SVD}) U, Vt = F.U, F.Vt - ∂S = ΔF.S c = _svd_eigvals_sign!(similar(F.S), U, Vt) + λ = F.S .* c + ∂λ = ΔF.S .* c ∂U = ΔF.U .+ (ΔF.Vt .+ ΔF.V') .* c' - ∂A = eigen_rev!(A, F.S .* c, U, ∂S .* c, ∂U) + ∂A = eigen_rev!(A, λ, U, ∂λ, ∂U) return NO_FIELDS, ∂A end svd_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) From c52ab77768824ca947db5771057019ff8c8b7990 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 21:34:00 -0800 Subject: [PATCH 20/62] Simplify _hermitrize! --- src/rulesets/LinearAlgebra/symmetric.jl | 17 ++++++++++------- 1 file changed, 10 insertions(+), 7 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index f422a2442..dd9e36c25 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -128,7 +128,7 @@ function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) ∂K[diagind(∂K)] .= real.(∂λ) mul!(tmp, ∂K, U') mul!(∂A.data, U, tmp) - _hermitrize!(∂A.data) + @inbounds _hermitrize!(∂A.data) end return ∂A end @@ -249,12 +249,15 @@ _symhermtype(::Type{<:Symmetric}) = Symmetric _symhermtype(::Type{<:Hermitian}) = Hermitian _symhermtype(A) = _symhermtype(typeof(A)) -# in-place hermitrize matrix, optionally wrapping like A +# in-place hermitrize matrix function _hermitrize!(A) - A .= (A .+ A') ./ 2 + n = size(A, 1) + for i in 1:n + for j in (i + 1):n + A[i, j] = (A[i, j] + conj(A[j, i])) / 2 + A[j, i] = conj(A[i, j]) + end + A[i, i] = real(A[i, i]) + end return A end -function _hermitrize!(∂A, A) - _hermitrize!(∂A) - return _symhermtype(A)(∂A, Symbol(A.uplo)) -end From df165e5f811ba61f53162de91200850ebeeb5192 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 21:34:09 -0800 Subject: [PATCH 21/62] Add svdvals section --- src/rulesets/LinearAlgebra/symmetric.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index dd9e36c25..4f34ea8c7 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -227,6 +227,10 @@ function _svd_eigvals_sign!(c, U, Vt) return c end +##### +##### `svdvals` +##### + function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) # sorting doesn't affect the eigvals pullback, and it simplifies this rrule λ, back = rrule(eigvals, A; sortby = x -> -abs2(x)) From 19d2450b6aa9df072013f0b836b0284eede7a3bd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 7 Dec 2020 21:44:33 -0800 Subject: [PATCH 22/62] Don't use convenience type not in v1.0 --- src/rulesets/LinearAlgebra/symmetric.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 4f34ea8c7..bfa7397ee 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -133,7 +133,7 @@ function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) return ∂A end -_eigen_norm_phase_fwd!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V +_eigen_norm_phase_fwd!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T<:Real,S} = ∂V function _eigen_norm_phase_fwd!(∂V, A::Hermitian, V) k = A.uplo === 'U' ? size(A, 1) : 1 @inbounds for i in axes(V, 2) @@ -144,7 +144,7 @@ function _eigen_norm_phase_fwd!(∂V, A::Hermitian, V) return ∂V end -_eigen_norm_phase_rev!(∂V, ::LinearAlgebra.RealHermSym, V) = ∂V +_eigen_norm_phase_rev!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T<:Real,S} = ∂V function _eigen_norm_phase_rev!(∂V, A::Hermitian, V) k = A.uplo === 'U' ? size(A, 1) : 1 @inbounds for i in axes(V, 2) From c472bb1ec9e5afc072053a4e57b9479a31555854 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 8 Dec 2020 02:47:25 -0800 Subject: [PATCH 23/62] Fix multiplication order --- src/rulesets/LinearAlgebra/symmetric.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index bfa7397ee..a4403b5c3 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -120,7 +120,7 @@ function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) ∂A = similar(A, eltype(U)) tmp = ∂U if ∂U isa AbstractZero - mul!(∂A.data, U, (U .* real.(∂λ))') + mul!(∂A.data, U, real.(∂λ) .* U') else _eigen_norm_phase_rev!(∂U, A, U) ∂K = mul!(∂A.data, U', ∂U) @@ -188,7 +188,7 @@ function rrule( function eigvals_pullback(Δλ) U = F.vectors ∂A = similar(A) - mul!(∂A.data, U, (U .* real.(Δλ))') + mul!(∂A.data, U, real.(Δλ) .* U') return NO_FIELDS, ∂A end eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) From ce158962e2234921fbc3bd11dc03adfb7fbafe80 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 8 Dec 2020 02:47:42 -0800 Subject: [PATCH 24/62] Remove ambiguity in signature --- src/rulesets/LinearAlgebra/symmetric.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index a4403b5c3..bdd2f39c9 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -134,7 +134,7 @@ function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) end _eigen_norm_phase_fwd!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T<:Real,S} = ∂V -function _eigen_norm_phase_fwd!(∂V, A::Hermitian, V) +function _eigen_norm_phase_fwd!(∂V, A::Hermitian{<:Complex}, V) k = A.uplo === 'U' ? size(A, 1) : 1 @inbounds for i in axes(V, 2) v = @view V[:, i] @@ -145,7 +145,7 @@ function _eigen_norm_phase_fwd!(∂V, A::Hermitian, V) end _eigen_norm_phase_rev!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T<:Real,S} = ∂V -function _eigen_norm_phase_rev!(∂V, A::Hermitian, V) +function _eigen_norm_phase_rev!(∂V, A::Hermitian{<:Complex}, V) k = A.uplo === 'U' ? size(A, 1) : 1 @inbounds for i in axes(V, 2) v, ∂v = @views V[:, i], ∂V[:, i] From 9a695e3dd1a8d8aaf55abb716f714131d9626a67 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 8 Dec 2020 02:48:02 -0800 Subject: [PATCH 25/62] Define missing variable --- src/rulesets/LinearAlgebra/symmetric.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index bdd2f39c9..b16bc6f7d 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -137,8 +137,8 @@ _eigen_norm_phase_fwd!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T function _eigen_norm_phase_fwd!(∂V, A::Hermitian{<:Complex}, V) k = A.uplo === 'U' ? size(A, 1) : 1 @inbounds for i in axes(V, 2) - v = @view V[:, i] - vₖ, ∂vₖ = real(v[k]), ∂V[k, i] + v, ∂v = @views V[:, i], ∂V[:, i] + vₖ, ∂vₖ = real(v[k]), ∂v[k] ∂v .-= v .* (imag(∂vₖ) / ifelse(iszero(vₖ), one(vₖ), vₖ)) end return ∂V From 6d25d6f0b6040796d25b7596cef9c7a2d4b616ef Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 8 Dec 2020 03:53:26 -0800 Subject: [PATCH 26/62] Make pure imaginary --- src/rulesets/LinearAlgebra/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index b16bc6f7d..b297a8cda 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -139,7 +139,7 @@ function _eigen_norm_phase_fwd!(∂V, A::Hermitian{<:Complex}, V) @inbounds for i in axes(V, 2) v, ∂v = @views V[:, i], ∂V[:, i] vₖ, ∂vₖ = real(v[k]), ∂v[k] - ∂v .-= v .* (imag(∂vₖ) / ifelse(iszero(vₖ), one(vₖ), vₖ)) + ∂v .-= v .* (im * (imag(∂vₖ) / ifelse(iszero(vₖ), one(vₖ), vₖ))) end return ∂V end From 5a852060b3b411775cbaf20c905cf30983c9a46c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 8 Dec 2020 03:53:48 -0800 Subject: [PATCH 27/62] Add tests for eigendecomposition rules --- test/rulesets/LinearAlgebra/symmetric.jl | 108 +++++++++++++++++++++++ 1 file changed, 108 insertions(+) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index f76c28836..0ba94e5d4 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -41,4 +41,112 @@ frule_test(f, (x, SymHerm(Δx, uplo))) rrule_test(f, ΔΩ, (x, ∂x)) end + + # symmetric/hermitian eigendecomposition follows the sign convention + # v = v * sign(real(vₖ)) * sign(vₖ)', where vₖ is the first or last coordinate + # in the eigenvector. This is unstable for finite differences, but using the convention + # v = v * sign(vₖ)' seems to be more stable, the (co)tangents are related as + # ∂v_ad = sign(real(vₖ)) * ∂v_fd + + function _eigvecs_stabilize_mat(vectors, uplo) + Ui = Symbol(uplo) === :U ? @view(vectors[end, :]) : @view(vectors[1, :]) + return Diagonal(conj.(sign.(Ui))) + end + + function _eigen_stable(A) + F = eigen(A) + rmul!(F.vectors, _eigvecs_stabilize_mat(F.vectors, A.uplo)) + return F + end + + @testset "eigendecomposition" begin + @testset "eigen/eigen!" begin + # avoid implementing to_vec(::Eigen) + asnt(E::Eigen) = (values=E.values, vectors=E.vectors) + + n = 10 + @testset "eigen!(::Hermitian{ComplexF64}) frule" for SymHerm in + (Symmetric, Hermitian), + T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), + uplo in (:L, :U) + + A, ΔA = SymHerm(randn(T, n, n), uplo), SymHerm(randn(T, n, n), uplo) + F = eigen!(copy(A)) + F_ad, ∂F_ad = frule((Zero(), copy(ΔA)), eigen!, copy(A)) + @test F_ad == F + @test ∂F_ad isa Composite{typeof(F)} + f = x -> asnt(eigen(SymHerm(x, uplo))) + f_stable = x -> asnt(_eigen_stable(SymHerm(x, uplo))) + ∂F_fd = jvp(_fdm, f, (A.data, ΔA.data)) + @test ∂F_ad.values ≈ ∂F_fd.values + F_stable = f_stable(A) + ∂F_stable_fd = jvp(_fdm, f_stable, (A.data, ΔA.data)) + C = _eigvecs_stabilize_mat(F.vectors, uplo) + @test ∂F_ad.vectors * C ≈ ∂F_stable_fd.vectors + end + + @testset "eigen(::Hermitian{ComplexF64}) rrule" for SymHerm in + (Symmetric, Hermitian), + T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), + uplo in (:L, :U) + + A, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(n) + symA = SymHerm(A, uplo) + F = eigen(symA) + ΔF = Composite{typeof(F)}(; values=Δλ, vectors=ΔU) + F_ad, back = rrule(eigen, symA) + @test F_ad == F + ∂self, ∂symA = back(ΔF) + @test ∂self === NO_FIELDS + ∂symA = unthunk(∂symA) + @test ∂symA isa typeof(symA) + @test ∂symA.uplo == symA.uplo + # pull the cotangent back to A to test against finite differences + ∂A = unthunk(rrule(SymHerm, A, uplo)[2](∂symA)[2]) + # adopt a deterministic sign convention to stabilize FD + C = _eigvecs_stabilize_mat(F.vectors, uplo) + ΔF_stable = Composite{typeof(F)}(; values=Δλ, vectors=ΔU * C) + f = x -> asnt(_eigen_stable(SymHerm(x, uplo))) + ∂A_fd = j′vp(_fdm, f, ΔF_stable, A)[1] + @test ∂A ≈ ∂A_fd + end + end + + @testset "eigvals!/eigvals" begin + n = 10 + @testset "eigvals!(::Hermitian{ComplexF64}) frule" for SymHerm in + (Symmetric, Hermitian), + T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), + uplo in (:L, :U) + + A, ΔA = SymHerm(randn(T, n, n), uplo), SymHerm(randn(T, n, n), uplo) + λ = eigvals!(copy(A)) + λ_ad, ∂λ_ad = frule((Zero(), ΔA), eigvals!, copy(A)) + @test λ_ad ≈ λ # inexact because frule uses eigen not eigvals + ∂λ_ad = unthunk(∂λ_ad) + @test ∂λ_ad isa typeof(λ) + @test ∂λ_ad ≈ jvp(_fdm, A -> eigvals(SymHerm(A, uplo)), (A.data, ΔA.data)) + end + + @testset "eigvals(::Hermitian{ComplexF64}) rrule" for SymHerm in + (Symmetric, Hermitian), + T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), + uplo in (:L, :U) + + A, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(n) + symA = SymHerm(A, uplo) + λ = eigvals(symA) + λ_ad, back = rrule(eigvals, symA) + @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals + ∂self, ∂symA = back(Δλ) + @test ∂self === NO_FIELDS + ∂symA = unthunk(∂symA) + @test ∂symA isa typeof(symA) + @test ∂symA.uplo == symA.uplo + # pull the cotangent back to A to test against finite differences + ∂A = unthunk(rrule(SymHerm, A, uplo)[2](∂symA)[2]) + @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] + end + end + end end From a0895e44a354ef89fbe1f71b3932637f0c1ff99b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 00:39:19 -0800 Subject: [PATCH 28/62] Test from nonsymmetric matrix --- test/rulesets/LinearAlgebra/symmetric.jl | 28 +++++++++++++++--------- 1 file changed, 18 insertions(+), 10 deletions(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 0ba94e5d4..52b314c82 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -70,17 +70,22 @@ T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) - A, ΔA = SymHerm(randn(T, n, n), uplo), SymHerm(randn(T, n, n), uplo) - F = eigen!(copy(A)) - F_ad, ∂F_ad = frule((Zero(), copy(ΔA)), eigen!, copy(A)) + A, ΔA, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(T, n, n), randn(n) + symA = SymHerm(A, uplo) + ΔsymA = frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] + + F = eigen!(copy(symA)) + F_ad, ∂F_ad = frule((Zero(), copy(ΔsymA)), eigen!, copy(symA)) @test F_ad == F @test ∂F_ad isa Composite{typeof(F)} + @test ∂F_ad.values isa typeof(F.values) + @test ∂F_ad.vectors isa typeof(F.vectors) f = x -> asnt(eigen(SymHerm(x, uplo))) - f_stable = x -> asnt(_eigen_stable(SymHerm(x, uplo))) - ∂F_fd = jvp(_fdm, f, (A.data, ΔA.data)) + ∂F_fd = jvp(_fdm, f, (A, ΔA)) @test ∂F_ad.values ≈ ∂F_fd.values + f_stable = x -> asnt(_eigen_stable(SymHerm(x, uplo))) F_stable = f_stable(A) - ∂F_stable_fd = jvp(_fdm, f_stable, (A.data, ΔA.data)) + ∂F_stable_fd = jvp(_fdm, f_stable, (A, ΔA)) C = _eigvecs_stabilize_mat(F.vectors, uplo) @test ∂F_ad.vectors * C ≈ ∂F_stable_fd.vectors end @@ -119,13 +124,16 @@ T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) - A, ΔA = SymHerm(randn(T, n, n), uplo), SymHerm(randn(T, n, n), uplo) - λ = eigvals!(copy(A)) - λ_ad, ∂λ_ad = frule((Zero(), ΔA), eigvals!, copy(A)) + A, ΔA, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(T, n, n), randn(n) + symA = SymHerm(A, uplo) + ΔsymA = frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] + + λ = eigvals!(copy(symA)) + λ_ad, ∂λ_ad = frule((Zero(), copy(ΔsymA)), eigvals!, copy(symA)) @test λ_ad ≈ λ # inexact because frule uses eigen not eigvals ∂λ_ad = unthunk(∂λ_ad) @test ∂λ_ad isa typeof(λ) - @test ∂λ_ad ≈ jvp(_fdm, A -> eigvals(SymHerm(A, uplo)), (A.data, ΔA.data)) + @test ∂λ_ad ≈ jvp(_fdm, A -> eigvals(SymHerm(A, uplo)), (A, ΔA)) end @testset "eigvals(::Hermitian{ComplexF64}) rrule" for SymHerm in From 3f5470aca814127dad5c2803fcf43f27d515cab8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 00:39:32 -0800 Subject: [PATCH 29/62] Add newlines --- test/rulesets/LinearAlgebra/symmetric.jl | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 52b314c82..5042280d5 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -80,9 +80,11 @@ @test ∂F_ad isa Composite{typeof(F)} @test ∂F_ad.values isa typeof(F.values) @test ∂F_ad.vectors isa typeof(F.vectors) + f = x -> asnt(eigen(SymHerm(x, uplo))) ∂F_fd = jvp(_fdm, f, (A, ΔA)) @test ∂F_ad.values ≈ ∂F_fd.values + f_stable = x -> asnt(_eigen_stable(SymHerm(x, uplo))) F_stable = f_stable(A) ∂F_stable_fd = jvp(_fdm, f_stable, (A, ΔA)) @@ -97,6 +99,7 @@ A, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(n) symA = SymHerm(A, uplo) + F = eigen(symA) ΔF = Composite{typeof(F)}(; values=Δλ, vectors=ΔU) F_ad, back = rrule(eigen, symA) @@ -106,6 +109,7 @@ ∂symA = unthunk(∂symA) @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo + # pull the cotangent back to A to test against finite differences ∂A = unthunk(rrule(SymHerm, A, uplo)[2](∂symA)[2]) # adopt a deterministic sign convention to stabilize FD @@ -143,6 +147,7 @@ A, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(n) symA = SymHerm(A, uplo) + λ = eigvals(symA) λ_ad, back = rrule(eigvals, symA) @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals @@ -151,6 +156,7 @@ ∂symA = unthunk(∂symA) @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo + # pull the cotangent back to A to test against finite differences ∂A = unthunk(rrule(SymHerm, A, uplo)[2](∂symA)[2]) @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] From 84111f9a4b06ccbe9ac2c30ed45157d594f85153 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 00:39:48 -0800 Subject: [PATCH 30/62] Remove unnecessary unthunks --- test/rulesets/LinearAlgebra/symmetric.jl | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 5042280d5..c25ae2e65 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -135,7 +135,6 @@ λ = eigvals!(copy(symA)) λ_ad, ∂λ_ad = frule((Zero(), copy(ΔsymA)), eigvals!, copy(symA)) @test λ_ad ≈ λ # inexact because frule uses eigen not eigvals - ∂λ_ad = unthunk(∂λ_ad) @test ∂λ_ad isa typeof(λ) @test ∂λ_ad ≈ jvp(_fdm, A -> eigvals(SymHerm(A, uplo)), (A, ΔA)) end @@ -153,12 +152,11 @@ @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals ∂self, ∂symA = back(Δλ) @test ∂self === NO_FIELDS - ∂symA = unthunk(∂symA) @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo # pull the cotangent back to A to test against finite differences - ∂A = unthunk(rrule(SymHerm, A, uplo)[2](∂symA)[2]) + ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] end end From 5f9c27f61d48743ff4ee5b5ef36346134258c388 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 00:49:14 -0800 Subject: [PATCH 31/62] Test type-stability --- test/rulesets/LinearAlgebra/symmetric.jl | 16 +++++++--------- 1 file changed, 7 insertions(+), 9 deletions(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index c25ae2e65..85f055d61 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -75,7 +75,7 @@ ΔsymA = frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] F = eigen!(copy(symA)) - F_ad, ∂F_ad = frule((Zero(), copy(ΔsymA)), eigen!, copy(symA)) + F_ad, ∂F_ad = @inferred frule((Zero(), copy(ΔsymA)), eigen!, copy(symA)) @test F_ad == F @test ∂F_ad isa Composite{typeof(F)} @test ∂F_ad.values isa typeof(F.values) @@ -102,17 +102,15 @@ F = eigen(symA) ΔF = Composite{typeof(F)}(; values=Δλ, vectors=ΔU) - F_ad, back = rrule(eigen, symA) + F_ad, back = @inferred rrule(eigen, symA) @test F_ad == F - ∂self, ∂symA = back(ΔF) + ∂self, ∂symA = @inferred back(ΔF) @test ∂self === NO_FIELDS - ∂symA = unthunk(∂symA) @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo # pull the cotangent back to A to test against finite differences - ∂A = unthunk(rrule(SymHerm, A, uplo)[2](∂symA)[2]) - # adopt a deterministic sign convention to stabilize FD + ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] C = _eigvecs_stabilize_mat(F.vectors, uplo) ΔF_stable = Composite{typeof(F)}(; values=Δλ, vectors=ΔU * C) f = x -> asnt(_eigen_stable(SymHerm(x, uplo))) @@ -133,7 +131,7 @@ ΔsymA = frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] λ = eigvals!(copy(symA)) - λ_ad, ∂λ_ad = frule((Zero(), copy(ΔsymA)), eigvals!, copy(symA)) + λ_ad, ∂λ_ad = @inferred frule((Zero(), copy(ΔsymA)), eigvals!, copy(symA)) @test λ_ad ≈ λ # inexact because frule uses eigen not eigvals @test ∂λ_ad isa typeof(λ) @test ∂λ_ad ≈ jvp(_fdm, A -> eigvals(SymHerm(A, uplo)), (A, ΔA)) @@ -148,9 +146,9 @@ symA = SymHerm(A, uplo) λ = eigvals(symA) - λ_ad, back = rrule(eigvals, symA) + λ_ad, back = @inferred rrule(eigvals, symA) @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals - ∂self, ∂symA = back(Δλ) + ∂self, ∂symA = @inferred back(Δλ) @test ∂self === NO_FIELDS @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo From a846bec0018dffeaca8f9e77a0d61a134f7ae06c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 02:03:13 -0800 Subject: [PATCH 32/62] Test mixtures of Zeros --- src/rulesets/LinearAlgebra/symmetric.jl | 5 ++-- test/rulesets/LinearAlgebra/symmetric.jl | 38 +++++++++++++++++------- 2 files changed, 30 insertions(+), 13 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index b297a8cda..c91c68f51 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -107,7 +107,8 @@ function rrule( function eigen_pullback(ΔF::Composite{<:Eigen}) λ, U = F.values, F.vectors Δλ, ΔU = ΔF.values, ΔF.vectors - ∂A = eigen_rev!(A, λ, U, Δλ, copy(ΔU)) + ΔU = ΔU isa AbstractZero ? ΔU : copy(ΔU) + ∂A = eigen_rev!(A, λ, U, Δλ, ΔU) return NO_FIELDS, ∂A end eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) @@ -116,7 +117,7 @@ end # ∂U is overwritten if not an `AbstractZero` function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) - ∂λ isa AbstractZero && ∂λ isa AbstractZero && return (NO_FIELDS, ∂λ + ∂U) + ∂λ isa AbstractZero && ∂U isa AbstractZero && return ∂λ + ∂U ∂A = similar(A, eltype(U)) tmp = ∂U if ∂U isa AbstractZero diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 85f055d61..468d57092 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -75,6 +75,7 @@ ΔsymA = frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] F = eigen!(copy(symA)) + @test @inferred(frule((Zero(), Zero()), eigen!, copy(symA))) == (F, Zero()) F_ad, ∂F_ad = @inferred frule((Zero(), copy(ΔsymA)), eigen!, copy(symA)) @test F_ad == F @test ∂F_ad isa Composite{typeof(F)} @@ -104,18 +105,31 @@ ΔF = Composite{typeof(F)}(; values=Δλ, vectors=ΔU) F_ad, back = @inferred rrule(eigen, symA) @test F_ad == F - ∂self, ∂symA = @inferred back(ΔF) - @test ∂self === NO_FIELDS - @test ∂symA isa typeof(symA) - @test ∂symA.uplo == symA.uplo - # pull the cotangent back to A to test against finite differences - ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] C = _eigvecs_stabilize_mat(F.vectors, uplo) - ΔF_stable = Composite{typeof(F)}(; values=Δλ, vectors=ΔU * C) - f = x -> asnt(_eigen_stable(SymHerm(x, uplo))) - ∂A_fd = j′vp(_fdm, f, ΔF_stable, A)[1] - @test ∂A ≈ ∂A_fd + CT = Composite{typeof(F)} + + @testset for nzprops in ([:values], [:vectors], [:values, :vectors]) + ∂F = CT(; [s => getproperty(ΔF, s) for s in nzprops]...) + ∂F_stable = CT(; [s => copy(getproperty(ΔF, s)) for s in nzprops]...) + :vectors in nzprops && rmul!(∂F_stable.vectors, C) + + f_stable = function(x) + F_ = _eigen_stable(SymHerm(x, uplo)) + return (; (s => getproperty(F_, s) for s in nzprops)...) + end + + ∂self, ∂symA = @inferred back(∂F) + @test ∂self === NO_FIELDS + @test ∂symA isa typeof(symA) + @test ∂symA.uplo == symA.uplo + ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] + ∂A_fd = j′vp(_fdm, f_stable, ∂F_stable, A)[1] + @test ∂A ≈ ∂A_fd + end + + @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + @test @inferred(back(CT())) == (NO_FIELDS, Zero()) end end @@ -128,7 +142,7 @@ A, ΔA, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(T, n, n), randn(n) symA = SymHerm(A, uplo) - ΔsymA = frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] + ΔsymA = @inferred frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] λ = eigvals!(copy(symA)) λ_ad, ∂λ_ad = @inferred frule((Zero(), copy(ΔsymA)), eigvals!, copy(symA)) @@ -156,6 +170,8 @@ # pull the cotangent back to A to test against finite differences ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] @test ∂A ≈ j′vp(_fdm, A -> eigvals(SymHerm(A, uplo)), Δλ, A)[1] + + @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) end end end From aeb3295ea5139034caa96ccb2cb1a945772f3c1c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 02:07:28 -0800 Subject: [PATCH 33/62] Use more informative testset names --- test/rulesets/LinearAlgebra/symmetric.jl | 12 ++++-------- 1 file changed, 4 insertions(+), 8 deletions(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 468d57092..d6d65f3b0 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -65,8 +65,7 @@ asnt(E::Eigen) = (values=E.values, vectors=E.vectors) n = 10 - @testset "eigen!(::Hermitian{ComplexF64}) frule" for SymHerm in - (Symmetric, Hermitian), + @testset "eigen!(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -93,8 +92,7 @@ @test ∂F_ad.vectors * C ≈ ∂F_stable_fd.vectors end - @testset "eigen(::Hermitian{ComplexF64}) rrule" for SymHerm in - (Symmetric, Hermitian), + @testset "eigen(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -135,8 +133,7 @@ @testset "eigvals!/eigvals" begin n = 10 - @testset "eigvals!(::Hermitian{ComplexF64}) frule" for SymHerm in - (Symmetric, Hermitian), + @testset "eigvals!(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -151,8 +148,7 @@ @test ∂λ_ad ≈ jvp(_fdm, A -> eigvals(SymHerm(A, uplo)), (A, ΔA)) end - @testset "eigvals(::Hermitian{ComplexF64}) rrule" for SymHerm in - (Symmetric, Hermitian), + @testset "eigvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) From 925c9c38fd7ce5c6b7dc85d7bf0c19e84476231e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 03:03:27 -0800 Subject: [PATCH 34/62] Fix svd pullback bugs --- src/rulesets/LinearAlgebra/symmetric.jl | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index c91c68f51..5cecf41a3 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -203,11 +203,15 @@ end function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) F = svd(A) function svd_pullback(ΔF::Composite{<:SVD}) - U, Vt = F.U, F.Vt - c = _svd_eigvals_sign!(similar(F.S), U, Vt) + U, V = F.U, F.V + c = _svd_eigvals_sign!(similar(F.S), U, V) λ = F.S .* c - ∂λ = ΔF.S .* c - ∂U = ΔF.U .+ (ΔF.Vt .+ ΔF.V') .* c' + ∂λ = ΔF.S isa AbstractZero ? ΔF.S : ΔF.S .* c + if all(x -> x isa AbstractZero, (ΔF.U, ΔF.V, ΔF.Vt)) + ∂U = ΔF.U + ΔF.V + ΔF.Vt + else + ∂U = ΔF.U .+ (ΔF.V .+ ΔF.Vt') .* c' + end ∂A = eigen_rev!(A, λ, U, ∂λ, ∂U) return NO_FIELDS, ∂A end @@ -216,14 +220,14 @@ function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.Bla end # given singular vectors, compute sign of eigenvalues corresponding to singular values -function _svd_eigvals_sign!(c, U, Vt) +function _svd_eigvals_sign!(c, U, V) n = size(U, 1) @inbounds broadcast!(c, eachindex(c)) do i u = @views U[:, i] # find element not close to zero # at least one element has abs2 ≥ 1/n > 1/(n + 1) k = findfirst(x -> (n + 1) * abs2(x) ≥ 1, u) - return sign(real(u[k]) * real(Vt[k, i])) + return sign(real(u[k]) * real(V[k, i])) end return c end From 184ba2c78ba39bf84f14a404893762407c144ee9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 03:04:27 -0800 Subject: [PATCH 35/62] Add svd pullback tests --- test/rulesets/LinearAlgebra/symmetric.jl | 69 +++++++++++++++++++++--- 1 file changed, 62 insertions(+), 7 deletions(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index d6d65f3b0..eb0b52c0e 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -53,17 +53,17 @@ return Diagonal(conj.(sign.(Ui))) end - function _eigen_stable(A) - F = eigen(A) - rmul!(F.vectors, _eigvecs_stabilize_mat(F.vectors, A.uplo)) - return F - end - @testset "eigendecomposition" begin @testset "eigen/eigen!" begin # avoid implementing to_vec(::Eigen) asnt(E::Eigen) = (values=E.values, vectors=E.vectors) + function _eigen_stable(A) + F = eigen(A) + rmul!(F.vectors, _eigvecs_stabilize_mat(F.vectors, A.uplo)) + return F + end + n = 10 @testset "eigen!(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), @@ -109,7 +109,7 @@ @testset for nzprops in ([:values], [:vectors], [:values, :vectors]) ∂F = CT(; [s => getproperty(ΔF, s) for s in nzprops]...) - ∂F_stable = CT(; [s => copy(getproperty(ΔF, s)) for s in nzprops]...) + ∂F_stable = (; [s => copy(getproperty(ΔF, s)) for s in nzprops]...) :vectors in nzprops && rmul!(∂F_stable.vectors, C) f_stable = function(x) @@ -171,4 +171,59 @@ end end end + + @testset "singular value decomposition" begin + # avoid implementing to_vec(::Eigen) + asnt(F::SVD) = (U=F.U, S=F.S, V=F.V, Vt=F.Vt) + + function _svd_stable(A) + F = svd(A) + C = _eigvecs_stabilize_mat(F.U, A.uplo) + rmul!(F.U, C) + lmul!(C, F.Vt) + return F + end + + n = 10 + @testset "svd(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), + uplo in (:L, :U) + + A, ΔU, ΔV, ΔVt = ntuple(_ -> randn(T, n, n), 4) + ΔS = randn(n) + symA = SymHerm(A, uplo) + + F = svd(symA) + CT = Composite{typeof(F)} + ΔF = CT(; U=ΔU, V=ΔV, Vt=ΔVt, S=ΔS) + F_ad, back = @inferred rrule(svd, symA) + @test F_ad == F + + C = _eigvecs_stabilize_mat(F.U, uplo) + + @testset for nzprops in ([:U], [:S], [:V], [:Vt], [:U, :S, :V, :Vt]) + ∂F = CT(; [s => getproperty(ΔF, s) for s in nzprops]...) + ∂F_stable = (; [s => copy(getproperty(ΔF, s)) for s in nzprops]...) + :U in nzprops && rmul!(∂F_stable.U, C) + :Vt in nzprops && lmul!(C, ∂F_stable.Vt) + :V in nzprops && rmul!(∂F_stable.V, C) + + f_stable = function(x) + F_ = _svd_stable(SymHerm(x, uplo)) + return (; (s => getproperty(F_, s) for s in nzprops)...) + end + + ∂self, ∂symA = @inferred back(∂F) + @test ∂self === NO_FIELDS + @test ∂symA isa typeof(symA) + @test ∂symA.uplo == symA.uplo + ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] + ∂A_fd = j′vp(_fdm, f_stable, ∂F_stable, A)[1] + @test ∂A ≈ ∂A_fd + end + + @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + @test @inferred(back(CT())) == (NO_FIELDS, Zero()) + end + end end From f4c479416dfd223088da65167ad7a7af9fa7108e Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 03:11:29 -0800 Subject: [PATCH 36/62] Return correct argument --- src/rulesets/LinearAlgebra/symmetric.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 5cecf41a3..84e70884f 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -242,8 +242,8 @@ function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS S = abs.(λ) function svdvals_pullback(ΔS) ∂λ = ΔS .* S ./ ifelse.(iszero.(λ), one.(λ), λ) - ∂A = back(∂λ) - return NO_FIELDS, ∂A + _, ∂A = back(∂λ) + return NO_FIELDS, unthunk(∂A) end svdvals_pullback(ΔS::AbstractZero) = (NO_FIELDS, ΔS) return S, svdvals_pullback From c64b321ea5981f04d07374f6a7f36f04c43101fd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 03:12:26 -0800 Subject: [PATCH 37/62] Remove unused (co)tangents --- test/rulesets/LinearAlgebra/symmetric.jl | 6 +++--- 1 file changed, 3 insertions(+), 3 deletions(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index eb0b52c0e..1338c807c 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -69,7 +69,7 @@ T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) - A, ΔA, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(T, n, n), randn(n) + A, ΔA = randn(T, n, n), randn(T, n, n) symA = SymHerm(A, uplo) ΔsymA = frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] @@ -137,7 +137,7 @@ T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) - A, ΔA, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(T, n, n), randn(n) + A, ΔA = randn(T, n, n), randn(T, n, n) symA = SymHerm(A, uplo) ΔsymA = @inferred frule((Zero(), ΔA, Zero()), SymHerm, A, uplo)[2] @@ -152,7 +152,7 @@ T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) - A, ΔU, Δλ = randn(T, n, n), randn(T, n, n), randn(n) + A, Δλ = randn(T, n, n), randn(n) symA = SymHerm(A, uplo) λ = eigvals(symA) From 8f514310c5d5b524b841a0009dd9505cc4a58fe6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 03:12:37 -0800 Subject: [PATCH 38/62] Add svdvals tests --- test/rulesets/LinearAlgebra/symmetric.jl | 22 ++++++++++++++++++++++ 1 file changed, 22 insertions(+) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 1338c807c..1ac568677 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -225,5 +225,27 @@ @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) @test @inferred(back(CT())) == (NO_FIELDS, Zero()) end + + @testset "svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), + uplo in (:L, :U) + + A, ΔS = randn(T, n, n), randn(n) + symA = SymHerm(A, uplo) + + S = svdvals(symA) + S_ad, back = @inferred rrule(svdvals, symA) + @test S_ad ≈ S # inexact because rrule uses svd not svdvals + ∂self, ∂symA = @inferred back(ΔS) + @test ∂self === NO_FIELDS + @test ∂symA isa typeof(symA) + @test ∂symA.uplo == symA.uplo + + # pull the cotangent back to A to test against finite differences + ∂A = rrule(SymHerm, A, uplo)[2](∂symA)[2] + @test ∂A ≈ j′vp(_fdm, A -> svdvals(SymHerm(A, uplo)), ΔS, A)[1] + + @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + end end end From f27d33303f485a97bf26281f8dda639b3f48f40a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 03:37:18 -0800 Subject: [PATCH 39/62] Fix typo --- test/rulesets/LinearAlgebra/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 1ac568677..4cdf1b35f 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -173,7 +173,7 @@ end @testset "singular value decomposition" begin - # avoid implementing to_vec(::Eigen) + # avoid implementing to_vec(::SVD) asnt(F::SVD) = (U=F.U, S=F.S, V=F.V, Vt=F.Vt) function _svd_stable(A) From 35d8c3169310f2a23572337fd1e758d06e363af6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 03:37:43 -0800 Subject: [PATCH 40/62] Restrict SVD test to greater than v1.3.0 --- test/rulesets/LinearAlgebra/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 4cdf1b35f..8c13ae0d8 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -185,7 +185,7 @@ end n = 10 - @testset "svd(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + VERSION ≥ v"1.3.0" && @testset "svd(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) From e39108cde90a04c617bfc5cf0b16cffdc591a1d3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 04:10:35 -0800 Subject: [PATCH 41/62] Only check type-stability on 1.6 --- test/rulesets/LinearAlgebra/symmetric.jl | 3 ++- 1 file changed, 2 insertions(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 8c13ae0d8..34471a73f 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -213,7 +213,8 @@ return (; (s => getproperty(F_, s) for s in nzprops)...) end - ∂self, ∂symA = @inferred back(∂F) + VERSION ≥ v"1.6.0-DEV.1686" && @inferred back(∂F) + ∂self, ∂symA = back(∂F) @test ∂self === NO_FIELDS @test ∂symA isa typeof(symA) @test ∂symA.uplo == symA.uplo From 3b30638a6df370c23f5b023de3b82d52e0182108 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 04:57:40 -0800 Subject: [PATCH 42/62] Avoid specifying sortby keyword This is not defined on earlier Julia versions --- src/rulesets/LinearAlgebra/symmetric.jl | 26 ++++++++++++++----------- 1 file changed, 15 insertions(+), 11 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 84e70884f..bee79e4f0 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -81,9 +81,9 @@ function frule( (_, ΔA), ::typeof(eigen!), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; - sortby::Union{Function,Nothing}=nothing, + kwargs..., ) - F = eigen!(A; sortby=sortby) + F = eigen!(A; kwargs...) ΔA isa AbstractZero && return F, ΔA λ, U = F.values, F.vectors tmp = U' * ΔA @@ -101,9 +101,9 @@ end function rrule( ::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; - sortby::Union{Function,Nothing}=nothing, + kwargs..., ) - F = eigen(A; sortby=sortby) + F = eigen(A; kwargs...) function eigen_pullback(ΔF::Composite{<:Eigen}) λ, U = F.values, F.vectors Δλ, ΔU = ΔF.values, ΔF.vectors @@ -165,10 +165,10 @@ function frule( (_, ΔA), ::typeof(eigvals!), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; - sortby::Union{Function,Nothing}=nothing, + kwargs..., ) - ΔA isa AbstractZero && return eigvals!(A; sortby=sortby), ΔA - F = eigen!(A; sortby=sortby) + ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA + F = eigen!(A; kwargs...) λ, U = F.values, F.vectors tmp = ΔA * U # diag(U' * tmp) without computing matrix product @@ -182,9 +182,9 @@ end function rrule( ::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; - sortby::Union{Function,Nothing}=nothing, + kwargs..., ) - F = eigen(A; sortby=sortby) + F = eigen(A; kwargs...) λ = F.values function eigvals_pullback(Δλ) U = F.vectors @@ -238,10 +238,14 @@ end function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) # sorting doesn't affect the eigvals pullback, and it simplifies this rrule - λ, back = rrule(eigvals, A; sortby = x -> -abs2(x)) + λ, back = rrule(eigvals, A) S = abs.(λ) + p = sortperm(S; rev=true) + permute!(S, p) function svdvals_pullback(ΔS) - ∂λ = ΔS .* S ./ ifelse.(iszero.(λ), one.(λ), λ) + ∂λ = real.(ΔS) + invpermute!(∂λ, p) + ∂λ .*= sign.(λ) _, ∂A = back(∂λ) return NO_FIELDS, unthunk(∂A) end From f1e655915b9c56246e9d26ea3f73f8c6483ac5cf Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 05:01:35 -0800 Subject: [PATCH 43/62] Remove obsolete comment --- src/rulesets/LinearAlgebra/symmetric.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index bee79e4f0..a3da539de 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -237,7 +237,6 @@ end ##### function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) - # sorting doesn't affect the eigvals pullback, and it simplifies this rrule λ, back = rrule(eigvals, A) S = abs.(λ) p = sortperm(S; rev=true) From 9c11b52fc5aa3a15dceab64640eb9b411970db65 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 14 Dec 2020 02:01:09 -0800 Subject: [PATCH 44/62] Abandon ship when derivatives explode --- src/rulesets/LinearAlgebra/symmetric.jl | 25 +++++++++++++++----- test/rulesets/LinearAlgebra/symmetric.jl | 29 ++++++++++++++++++++++++ 2 files changed, 48 insertions(+), 6 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index a3da539de..3e9534d66 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -134,13 +134,22 @@ function eigen_rev!(A::LinearAlgebra.RealHermSymComplexHerm, λ, U, ∂λ, ∂U) return ∂A end +# NOTE: for small vₖ, the derivative of sign(vₖ) explodes, causing the tangents to become +# unstable even for phase-invariant programs. So for small vₖ we don't account for the phase +# in the gradient. Then derivatives are accurate for phase-invariant programs but inaccurate +# for phase-dependent programs that have low vₖ. + _eigen_norm_phase_fwd!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T<:Real,S} = ∂V function _eigen_norm_phase_fwd!(∂V, A::Hermitian{<:Complex}, V) k = A.uplo === 'U' ? size(A, 1) : 1 + ϵ = sqrt(eps(real(eltype(V)))) @inbounds for i in axes(V, 2) - v, ∂v = @views V[:, i], ∂V[:, i] - vₖ, ∂vₖ = real(v[k]), ∂v[k] - ∂v .-= v .* (im * (imag(∂vₖ) / ifelse(iszero(vₖ), one(vₖ), vₖ))) + v = @view V[:, i] + vₖ = real(v[k]) + if abs(vₖ) > ϵ + ∂v = @view ∂V[:, i] + ∂v .-= v .* (im * (imag(∂v[k]) / vₖ)) + end end return ∂V end @@ -148,11 +157,15 @@ end _eigen_norm_phase_rev!(∂V, ::Union{Symmetric{T,S},Hermitian{T,S}}, V) where {T<:Real,S} = ∂V function _eigen_norm_phase_rev!(∂V, A::Hermitian{<:Complex}, V) k = A.uplo === 'U' ? size(A, 1) : 1 + ϵ = sqrt(eps(real(eltype(V)))) @inbounds for i in axes(V, 2) - v, ∂v = @views V[:, i], ∂V[:, i] + v = @view V[:, i] vₖ = real(v[k]) - ∂c = dot(v, ∂v) - ∂v[k] -= im * (imag(∂c) / ifelse(iszero(vₖ), one(vₖ), vₖ)) + if abs(vₖ) > ϵ + ∂v = @view ∂V[:, i] + ∂c = dot(v, ∂v) + ∂v[k] -= im * (imag(∂c) / vₖ) + end end return ∂V end diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 34471a73f..691c5a26f 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -129,6 +129,35 @@ @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) @test @inferred(back(CT())) == (NO_FIELDS, Zero()) end + + @testset "phase convention from low value" begin + @testset for min_val in [0, eps(), sqrt(eps()), cbrt(eps()), eps()^(1//4)], + uplo in (:U, :L) + + U = randn(ComplexF64, n, n) + U[uplo === :U ? n : 1] = min_val + U = Matrix(qr(U).Q) + λ = sort(randn(n)) + A = Hermitian(U * Diagonal(λ) * U') + function f(A) + V = eigen(A).vectors + return V * V' + end + + Ȧ = Hermitian(randn(eltype(A), size(A))) + F, Ḟ_ad = frule((Zero(), copy(Ȧ)), eigen!, copy(A)) + V, V̇_ad = F.vectors, Ḟ_ad.vectors + Ω̇_ad = V̇_ad' * V + V' * V̇_ad + @test maximum(abs, Ω̇_ad) < sqrt(eps()) + + Ω̄ = randn(eltype(A), (n, n)) + V̄ = V * (Ω̄ + Ω̄') + F̄ = Composite{typeof(F)}(vectors = V̄) + _, back = rrule(eigen, A) + Ā = back(F̄)[2] + @test maximum(abs, Ā) < sqrt(eps()) + end + end end @testset "eigvals!/eigvals" begin From 443b7e6925ae83e2ca69c8349703427ad7cd657b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 14 Dec 2020 11:57:34 -0800 Subject: [PATCH 45/62] Handle Hermitian special-case for general eigen --- src/rulesets/LinearAlgebra/factorization.jl | 47 +++++++++++++++------ 1 file changed, 33 insertions(+), 14 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 00a439f04..d43ea33f1 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -76,8 +76,12 @@ end # - support degenerate matrices (see #144) function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} + if ΔA isa AbstractZero + F = eigen!(A; kwargs...) + return F, ΔA + end + ishermitian(A) && return frule((Zero(), Hermitian(ΔA)), eigen!, Hermitian(A); kwargs...) F = eigen!(A; kwargs...) - ΔA isa AbstractZero && return F, ΔA λ, V = F.values, F.vectors tmp = V \ ΔA ∂K = tmp * V @@ -92,21 +96,36 @@ function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where end function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}} - F = eigen(A; kwargs...) + # NOTE: this check for hermitian-ness occurs in `eigen!`. We here do it in the rrule for + # `eigen` so that this works for non-mutating AD + isherm = ishermitian(A) + if isherm + hermA, back_Hermitian = rrule(Hermitian, A, :U) + F, back_eigen = rrule(eigen, hermA; kwargs...) + else + F = eigen(A; kwargs...) + end function eigen_pullback(ΔF::Composite{<:Eigen}) - λ, V = F.values, F.vectors - Δλ, ΔV = ΔF.values, ΔF.vectors - if ΔV isa AbstractZero - Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) - ∂K = Diagonal(Δλ) - ∂A = V' \ ∂K * V' + if isherm + _, ∂hermA = back_eigen(ΔF) + ∂hermA isa AbstractZero && return (NO_FIELDS, ∂hermA) + _, ∂Atriu = back_Hermitian(∂hermA) + ∂A = triu!(∂Atriu.data) else - ∂V = copyto!(similar(ΔV), ΔV) - _eigen_norm_phase_rev!(∂V, A, V) - ∂K = V' * ∂V - ∂K ./= λ' .- conj.(λ) - ∂K[diagind(∂K)] .= Δλ - ∂A = mul!(∂K, V' \ ∂K, V') + λ, V = F.values, F.vectors + Δλ, ΔV = ΔF.values, ΔF.vectors + if ΔV isa AbstractZero + Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) + ∂K = Diagonal(Δλ) + ∂A = V' \ ∂K * V' + else + ∂V = copyto!(similar(ΔV), ΔV) + _eigen_norm_phase_rev!(∂V, A, V) + ∂K = V' * ∂V + ∂K ./= λ' .- conj.(λ) + ∂K[diagind(∂K)] .= Δλ + ∂A = mul!(∂K, V' \ ∂K, V') + end end return NO_FIELDS, T <: Real ? real(∂A) : ∂A end From c2cb8597fa7ad89f7867f1e0a2e0c8ac43d5605b Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 14 Dec 2020 13:00:05 -0800 Subject: [PATCH 46/62] Fix comment --- src/rulesets/LinearAlgebra/symmetric.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 3e9534d66..1d23313df 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -238,7 +238,7 @@ function _svd_eigvals_sign!(c, U, V) @inbounds broadcast!(c, eachindex(c)) do i u = @views U[:, i] # find element not close to zero - # at least one element has abs2 ≥ 1/n > 1/(n + 1) + # at least one element satisfies abs2(x) ≥ 1/n > 1/(n + 1) k = findfirst(x -> (n + 1) * abs2(x) ≥ 1, u) return sign(real(u[k]) * real(V[k, i])) end From e1e3af70a368fcec77122c83d07f63ceaefc0775 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 14 Dec 2020 13:01:40 -0800 Subject: [PATCH 47/62] Resolve type-instability --- src/rulesets/LinearAlgebra/factorization.jl | 45 +++++++++------------ 1 file changed, 18 insertions(+), 27 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index d43ea33f1..7f653316f 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -96,36 +96,27 @@ function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where end function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}} - # NOTE: this check for hermitian-ness occurs in `eigen!`. We here do it in the rrule for - # `eigen` so that this works for non-mutating AD - isherm = ishermitian(A) - if isherm - hermA, back_Hermitian = rrule(Hermitian, A, :U) - F, back_eigen = rrule(eigen, hermA; kwargs...) - else - F = eigen(A; kwargs...) - end + F = eigen(A; kwargs...) function eigen_pullback(ΔF::Composite{<:Eigen}) - if isherm - _, ∂hermA = back_eigen(ΔF) - ∂hermA isa AbstractZero && return (NO_FIELDS, ∂hermA) - _, ∂Atriu = back_Hermitian(∂hermA) + λ, V = F.values, F.vectors + Δλ, ΔV = ΔF.values, ΔF.vectors + ΔV isa AbstractZero && Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) + if ishermitian(A) + hermA = Hermitian(A) + ∂V = copyto!(similar(ΔV), ΔV) + ∂hermA = eigen_rev!(hermA, λ, V, Δλ, ∂V) + ∂Atriu = _symherm_back(typeof(hermA), ∂hermA, hermA.uplo) ∂A = triu!(∂Atriu.data) + elseif ΔV isa AbstractZero + ∂K = Diagonal(Δλ) + ∂A = V' \ ∂K * V' else - λ, V = F.values, F.vectors - Δλ, ΔV = ΔF.values, ΔF.vectors - if ΔV isa AbstractZero - Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) - ∂K = Diagonal(Δλ) - ∂A = V' \ ∂K * V' - else - ∂V = copyto!(similar(ΔV), ΔV) - _eigen_norm_phase_rev!(∂V, A, V) - ∂K = V' * ∂V - ∂K ./= λ' .- conj.(λ) - ∂K[diagind(∂K)] .= Δλ - ∂A = mul!(∂K, V' \ ∂K, V') - end + ∂V = copyto!(similar(ΔV), ΔV) + _eigen_norm_phase_rev!(∂V, A, V) + ∂K = V' * ∂V + ∂K ./= λ' .- conj.(λ) + ∂K[diagind(∂K)] .= Δλ + ∂A = mul!(∂K, V' \ ∂K, V') end return NO_FIELDS, T <: Real ? real(∂A) : ∂A end From 1246efd6440a3f5868093796c946c1d1fe40a721 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 14 Dec 2020 14:00:40 -0800 Subject: [PATCH 48/62] =?UTF-8?q?Handle=20when=20just=20=CE=94V=20is=20Zer?= =?UTF-8?q?o?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- src/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 7f653316f..824ab3880 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -103,7 +103,7 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ ΔV isa AbstractZero && Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) if ishermitian(A) hermA = Hermitian(A) - ∂V = copyto!(similar(ΔV), ΔV) + ∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV) ∂hermA = eigen_rev!(hermA, λ, V, Δλ, ∂V) ∂Atriu = _symherm_back(typeof(hermA), ∂hermA, hermA.uplo) ∂A = triu!(∂Atriu.data) From a9a896c0af825ecb8ff042c64b6c6837c2bc22a9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 14 Dec 2020 14:01:05 -0800 Subject: [PATCH 49/62] Test eigen for hermitian Matrix --- test/rulesets/LinearAlgebra/factorization.jl | 72 +++++++++++++++++++- 1 file changed, 69 insertions(+), 3 deletions(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 416f44169..bd4977ccb 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -90,7 +90,7 @@ end # NOTE: eigen!/eigen are not type-stable, so neither are their frule/rrule # avoid implementing to_vec(::Eigen) - f(E::Eigen) = (values=E.values, vectors=E.vectors) + asnt(E::Eigen) = (values=E.values, vectors=E.vectors) # NOTE: for unstructured matrices, low enough n, and this specific seed, finite # differences of eigen seems to be stable enough for direct comparison. @@ -105,7 +105,7 @@ end F_fwd, Ḟ_ad = frule((Zero(), copy(Ẋ)), eigen!, copy(X)) @test F_fwd == F @test Ḟ_ad isa Composite{typeof(F)} - Ḟ_fd = jvp(_fdm, f ∘ eigen! ∘ copy, (X, Ẋ)) + Ḟ_fd = jvp(_fdm, asnt ∘ eigen! ∘ copy, (X, Ẋ)) @test Ḟ_ad.values ≈ Ḟ_fd.values @test Ḟ_ad.vectors ≈ Ḟ_fd.vectors @test frule((Zero(), Zero()), eigen!, copy(X)) == (F, Zero()) @@ -136,7 +136,7 @@ end F̄ = CT(values = λ̄, vectors = V̄) s̄elf, X̄_ad = @inferred back(F̄) @test s̄elf === NO_FIELDS - X̄_fd = j′vp(_fdm, f ∘ eigen, F̄, X)[1] + X̄_fd = j′vp(_fdm, asnt ∘ eigen, F̄, X)[1] @test X̄_ad ≈ X̄_fd @test @inferred(back(Zero())) === (NO_FIELDS, Zero()) F̄zero = CT(values = Zero(), vectors = Zero()) @@ -175,6 +175,72 @@ end V̄proj2 = ChainRules._eigen_norm_phase_rev!(copy(V̄proj), X, F.vectors) @test V̄proj2 ≈ V̄proj end + + # below tests adapted from /test/rulesets/LinearAlgebra/symmetric.jl + @testset "hermitian matrices" begin + function _eigvecs_stabilize_mat(vectors) + Ui = @view(vectors[end, :]) + return Diagonal(conj.(sign.(Ui))) + end + + function _eigen_stable(A) + F = eigen(A) + rmul!(F.vectors, _eigvecs_stabilize_mat(F.vectors)) + return F + end + + n = 10 + @testset "eigen!(::Matrix{$T})" for T in (Float64, ComplexF64) + A, ΔA = Matrix(Hermitian(randn(T, n, n))), Matrix(Hermitian(randn(T, n, n))) + + F = eigen!(copy(A)) + @test frule((Zero(), Zero()), eigen!, copy(A)) == (F, Zero()) + F_ad, ∂F_ad = frule((Zero(), copy(ΔA)), eigen!, copy(A)) + @test F_ad == F + @test ∂F_ad isa Composite{typeof(F)} + @test ∂F_ad.values isa typeof(F.values) + @test ∂F_ad.vectors isa typeof(F.vectors) + + f = x -> asnt(eigen(Matrix(Hermitian(x)))) + ∂F_fd = jvp(_fdm, f, (A, ΔA)) + @test ∂F_ad.values ≈ ∂F_fd.values + + f_stable = x -> asnt(_eigen_stable(Matrix(Hermitian(x)))) + F_stable = f_stable(A) + ∂F_stable_fd = jvp(_fdm, f_stable, (A, ΔA)) + C = _eigvecs_stabilize_mat(F.vectors) + @test ∂F_ad.vectors * C ≈ ∂F_stable_fd.vectors + end + + @testset "eigen(::Matrix{$T})" for T in (Float64, ComplexF64) + A, ΔU, Δλ = Hermitian(randn(T, n, n)), randn(T, n, n), randn(n) + + F = eigen(A) + ΔF = Composite{typeof(F)}(; values=Δλ, vectors=ΔU) + F_ad, back = rrule(eigen, A) + @test F_ad == F + + C = _eigvecs_stabilize_mat(F.vectors) + CT = Composite{typeof(F)} + + @testset for nzprops in ([:values], [:vectors], [:values, :vectors]) + ∂F = CT(; [s => getproperty(ΔF, s) for s in nzprops]...) + ∂F_stable = (; [s => copy(getproperty(ΔF, s)) for s in nzprops]...) + :vectors in nzprops && rmul!(∂F_stable.vectors, C) + + f_stable = function(x) + F_ = _eigen_stable(Matrix(Hermitian(x))) + return (; (s => getproperty(F_, s) for s in nzprops)...) + end + + ∂self, ∂A = @inferred back(∂F) + @test ∂self === NO_FIELDS + @test ∂A isa typeof(A) + ∂A_fd = j′vp(_fdm, f_stable, ∂F_stable, A)[1] + @test ∂A ≈ ∂A_fd + end + end + end end @testset "eigvals/eigvals!" begin From 05a538379d39ac90661c15a0e9db400f49a64a38 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Mon, 14 Dec 2020 14:39:16 -0800 Subject: [PATCH 50/62] Make hermitian Matrix --- test/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index bd4977ccb..4a0aa596b 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -213,7 +213,7 @@ end end @testset "eigen(::Matrix{$T})" for T in (Float64, ComplexF64) - A, ΔU, Δλ = Hermitian(randn(T, n, n)), randn(T, n, n), randn(n) + A, ΔU, Δλ = Matrix(Hermitian(randn(T, n, n))), randn(T, n, n), randn(n) F = eigen(A) ΔF = Composite{typeof(F)}(; values=Δλ, vectors=ΔU) From 41140cf3c04bfca3708784a0d77687de8c6f9f40 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 15 Dec 2020 23:48:46 -0800 Subject: [PATCH 51/62] Call eigen pullback from eigvals --- src/rulesets/LinearAlgebra/symmetric.jl | 8 +++----- 1 file changed, 3 insertions(+), 5 deletions(-) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 1d23313df..62c9b7bcb 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -197,15 +197,13 @@ function rrule( A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}; kwargs..., ) - F = eigen(A; kwargs...) + F, eigen_back = rrule(eigen, A; kwargs...) λ = F.values function eigvals_pullback(Δλ) - U = F.vectors - ∂A = similar(A) - mul!(∂A.data, U, real.(Δλ) .* U') + ∂F = Composite{typeof(F)}(values = Δλ) + _, ∂A = eigen_back(∂F) return NO_FIELDS, ∂A end - eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) return λ, eigvals_pullback end From 1866acdb71354776480581c6e2eefa7d1cc680ba Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 15 Dec 2020 23:49:24 -0800 Subject: [PATCH 52/62] Only pass sortby to Hermitian eigen --- src/rulesets/LinearAlgebra/factorization.jl | 12 ++++++++---- 1 file changed, 8 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 824ab3880..50f779c34 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -76,11 +76,15 @@ end # - support degenerate matrices (see #144) function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} - if ΔA isa AbstractZero - F = eigen!(A; kwargs...) - return F, ΔA + ΔA isa AbstractZero && return (eigen!(A; kwargs...), ΔA) + if ishermitian(A) + sortby = get(kwargs, :sortby, nothing) + return if sortby === nothing + frule((Zero(), Hermitian(ΔA)), eigen!, Hermitian(A)) + else + frule((Zero(), Hermitian(ΔA)), eigen!, Hermitian(A); sortby=sortby) + end end - ishermitian(A) && return frule((Zero(), Hermitian(ΔA)), eigen!, Hermitian(A); kwargs...) F = eigen!(A; kwargs...) λ, V = F.values, F.vectors tmp = V \ ΔA From 6d777657c550df9f26a9fe4c2b487e78f8004e0c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 15 Dec 2020 23:49:47 -0800 Subject: [PATCH 53/62] Support Julia 1.0's return for _symherm_back --- src/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 50f779c34..64b7d5364 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -110,7 +110,7 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ ∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV) ∂hermA = eigen_rev!(hermA, λ, V, Δλ, ∂V) ∂Atriu = _symherm_back(typeof(hermA), ∂hermA, hermA.uplo) - ∂A = triu!(∂Atriu.data) + ∂A = ∂Atriu isa AbstractTriangular ? triu!(∂Atriu.data) : ∂Atriu elseif ΔV isa AbstractZero ∂K = Diagonal(Δλ) ∂A = V' \ ∂K * V' From 1a211a4d43dcee0facb8f9fe69c6716517efe054 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 15 Dec 2020 23:50:53 -0800 Subject: [PATCH 54/62] Call Hermitian eigvals! frule --- src/rulesets/LinearAlgebra/factorization.jl | 32 +++++++++++++++------ 1 file changed, 24 insertions(+), 8 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 64b7d5364..dd4a1338b 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -187,15 +187,20 @@ end function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA - F = eigen!(A; kwargs...) - λ, V = F.values, F.vectors - tmp = V \ ΔA - ∂λ = similar(λ) - # diag(tmp * V) without computing full matrix product - if eltype(∂λ) <: Real - broadcast!((a, b) -> sum(real ∘ prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V)) + if ishermitian(A) + λ, ∂λ = frule((Zero(), Hermitian(ΔA)), eigvals!, Hermitian(A)) + _sorteig!_fwd(∂λ, λ, get(kwargs, :sortby, nothing)) else - broadcast!((a, b) -> sum(prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V)) + F = eigen!(A; kwargs...) + λ, V = F.values, F.vectors + tmp = V \ ΔA + ∂λ = similar(λ) + # diag(tmp * V) without computing full matrix product + if eltype(∂λ) <: Real + broadcast!((a, b) -> sum(real ∘ prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V)) + else + broadcast!((a, b) -> sum(prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V)) + end end return λ, ∂λ end @@ -212,6 +217,17 @@ function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Unio return λ, eigvals_pullback end +# adapted from LinearAlgebra.sorteig! +function _sorteig!_fwd(Δλ, λ, sortby) + Δλ isa AbstractZero && return (sort!(λ; by=sortby), Δλ) + if sortby !== nothing + p = sortperm(λ; alg=QuickSort, by=sortby) + permute!(λ, p) + permute!(Δλ, p) + end + return (λ, Δλ) +end + ##### ##### `cholesky` ##### From 21020f0f4eee76859f4587570ea5be3d217218f9 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 15 Dec 2020 23:51:10 -0800 Subject: [PATCH 55/62] Call eigen rrule in eigval rrule --- src/rulesets/LinearAlgebra/factorization.jl | 9 ++++----- 1 file changed, 4 insertions(+), 5 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index dd4a1338b..69e37ae4b 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -206,14 +206,13 @@ function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) whe end function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}} - F = eigen(A; kwargs...) + F, eigen_back = rrule(eigen, A; kwargs...) λ = F.values function eigvals_pullback(Δλ) - V = F.vectors - ∂A = V' \ Diagonal(Δλ) * V' - return NO_FIELDS, T <: Real ? real(∂A) : ∂A + ∂F = Composite{typeof(F)}(values = Δλ) + _, ∂A = eigen_back(∂F) + return NO_FIELDS, ∂A end - eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) return λ, eigvals_pullback end From 50bb4e59ea56a016cb5fb4d4b0b6183b7519cd03 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 15 Dec 2020 23:51:26 -0800 Subject: [PATCH 56/62] Test eigvals for hermitian Matrix-es --- test/rulesets/LinearAlgebra/factorization.jl | 25 ++++++++++++++++++++ 1 file changed, 25 insertions(+) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 4a0aa596b..dcb43aa25 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -280,6 +280,31 @@ end @test eltype(X̄) <: Real end end + + # below tests adapted from /test/rulesets/LinearAlgebra/symmetric.jl + @testset "hermitian matrices" begin + n = 10 + @testset "eigvals!(::Matrix{$T})" for T in (Float64, ComplexF64) + A, ΔA = Matrix(Hermitian(randn(T, n, n))), Matrix(Hermitian(randn(T, n, n))) + λ = eigvals!(copy(A)) + λ_ad, ∂λ_ad = frule((Zero(), copy(ΔA)), eigvals!, copy(A)) + @test λ_ad ≈ λ # inexact because frule uses eigen not eigvals + @test ∂λ_ad isa typeof(λ) + @test ∂λ_ad ≈ jvp(_fdm, A -> eigvals(Matrix(Hermitian(A))), (A, ΔA)) + end + + @testset "eigvals(::Matrix{$T})" for T in (Float64, ComplexF64) + A, Δλ = Matrix(Hermitian(randn(T, n, n))), randn(n) + λ = eigvals(A) + λ_ad, back = rrule(eigvals, A) + @test λ_ad ≈ λ # inexact because rrule uses eigen not eigvals + ∂self, ∂A = @inferred back(Δλ) + @test ∂self === NO_FIELDS + @test ∂A isa typeof(A) + @test ∂A ≈ j′vp(_fdm, A -> eigvals(Matrix(Hermitian(A))), Δλ, A)[1] + @test @inferred(back(Zero())) == (NO_FIELDS, Zero()) + end + end end end From c2921093f08c1d53aff67dd776cc80186c16a610 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 16 Dec 2020 02:17:56 -0800 Subject: [PATCH 57/62] Do less expensive eltype check first --- src/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 69e37ae4b..a0bd74fe5 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -105,7 +105,7 @@ function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{ λ, V = F.values, F.vectors Δλ, ΔV = ΔF.values, ΔF.vectors ΔV isa AbstractZero && Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) - if ishermitian(A) + if eltype(λ) <: Real && ishermitian(A) hermA = Hermitian(A) ∂V = ΔV isa AbstractZero ? ΔV : copyto!(similar(ΔV), ΔV) ∂hermA = eigen_rev!(hermA, λ, V, Δλ, ∂V) From 22d02476d2a0d9b5ea5919e81d89b47a791d05f3 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 16 Dec 2020 02:29:23 -0800 Subject: [PATCH 58/62] Correctly handle sortby default --- src/rulesets/LinearAlgebra/factorization.jl | 5 +++-- 1 file changed, 3 insertions(+), 2 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index a0bd74fe5..9b081d9e7 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -78,7 +78,7 @@ end function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} ΔA isa AbstractZero && return (eigen!(A; kwargs...), ΔA) if ishermitian(A) - sortby = get(kwargs, :sortby, nothing) + sortby = get(kwargs, :sortby, VERSION ≥ v"1.2.0" ? LinearAlgebra.eigsortby : nothing) return if sortby === nothing frule((Zero(), Hermitian(ΔA)), eigen!, Hermitian(A)) else @@ -189,7 +189,8 @@ function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) whe ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA if ishermitian(A) λ, ∂λ = frule((Zero(), Hermitian(ΔA)), eigvals!, Hermitian(A)) - _sorteig!_fwd(∂λ, λ, get(kwargs, :sortby, nothing)) + sortby = get(kwargs, :sortby, VERSION ≥ v"1.2.0" ? LinearAlgebra.eigsortby : nothing) + _sorteig!_fwd(∂λ, λ, sortby) else F = eigen!(A; kwargs...) λ, V = F.values, F.vectors From 10f0902e54c79742e0b846a2a5e4845f85af3abd Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 5 Jan 2021 00:47:34 -0800 Subject: [PATCH 59/62] Increment version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 56a91477f..8c86a91a9 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.39" +version = "0.7.43" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" From 3a1375c23693b4cc192a0cb592a86fde01d01798 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 5 Jan 2021 01:02:41 -0800 Subject: [PATCH 60/62] Add references and notes --- src/rulesets/LinearAlgebra/symmetric.jl | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/src/rulesets/LinearAlgebra/symmetric.jl b/src/rulesets/LinearAlgebra/symmetric.jl index 62c9b7bcb..5a0a32ff5 100644 --- a/src/rulesets/LinearAlgebra/symmetric.jl +++ b/src/rulesets/LinearAlgebra/symmetric.jl @@ -77,6 +77,17 @@ end ##### `eigen!`/`eigen` ##### +# rule is old but the usual references are +# real rules: +# Giles M. B., An extended collection of matrix derivative results for forward and reverse +# mode algorithmic differentiation. +# https://people.maths.ox.ac.uk/gilesm/files/NA-08-01.pdf. +# complex rules: +# Boeddeker C., Hanebrink P., et al, On the Computation of Complex-valued Gradients with +# Application to Statistically Optimum Beamforming. arXiv:1701.00392v2 [cs.NA] +# +# accounting for normalization convention appears in Boeddeker && Hanebrink. +# account for phase convention is unpublished. function frule( (_, ΔA), ::typeof(eigen!), @@ -211,6 +222,9 @@ end ##### `svd` ##### +# NOTE: rrule defined because the `svd` primal mutates after calling `eigen`. +# otherwise, this rule just applies the chain rule and can be removed when mutation +# is supported by reverse-mode AD packages function rrule(::typeof(svd), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) F = svd(A) function svd_pullback(ΔF::Composite{<:SVD}) @@ -247,6 +261,8 @@ end ##### `svdvals` ##### +# NOTE: rrule defined because `svdvals` calls mutating `svdvals!` internally. +# can be removed when mutation is supported by reverse-mode AD packages function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS.BlasReal,<:StridedMatrix}) λ, back = rrule(eigvals, A) S = abs.(λ) From 2df370861fc541cb99a2efaafcf41923b555829d Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 7 Jan 2021 04:28:25 -0800 Subject: [PATCH 61/62] More clearly name tests --- test/rulesets/LinearAlgebra/symmetric.jl | 13 +++++++------ 1 file changed, 7 insertions(+), 6 deletions(-) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 691c5a26f..599d4f778 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -65,7 +65,7 @@ end n = 10 - @testset "eigen!(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + @testset "frule for eigen!(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -92,7 +92,7 @@ @test ∂F_ad.vectors * C ≈ ∂F_stable_fd.vectors end - @testset "eigen(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + @testset "rrule for eigen(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -162,7 +162,7 @@ @testset "eigvals!/eigvals" begin n = 10 - @testset "eigvals!(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + @testset "frule for eigvals!(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -177,7 +177,7 @@ @test ∂λ_ad ≈ jvp(_fdm, A -> eigvals(SymHerm(A, uplo)), (A, ΔA)) end - @testset "eigvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + @testset "rrule for eigvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -214,7 +214,7 @@ end n = 10 - VERSION ≥ v"1.3.0" && @testset "svd(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + VERSION ≥ v"1.3.0" && @testset "rrule for svd(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) @@ -230,6 +230,7 @@ C = _eigvecs_stabilize_mat(F.U, uplo) + # pull back different combination of non-Zero cotangents @testset for nzprops in ([:U], [:S], [:V], [:Vt], [:U, :S, :V, :Vt]) ∂F = CT(; [s => getproperty(ΔF, s) for s in nzprops]...) ∂F_stable = (; [s => copy(getproperty(ΔF, s)) for s in nzprops]...) @@ -256,7 +257,7 @@ @test @inferred(back(CT())) == (NO_FIELDS, Zero()) end - @testset "svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), + @testset "rrule for svdvals(::$SymHerm{$T}) uplo=$uplo" for SymHerm in (Symmetric, Hermitian), T in (SymHerm === Symmetric ? (Float64,) : (Float64, ComplexF64)), uplo in (:L, :U) From 4c3f04324d66c6944ae0039c44a370ca0bffefb0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 7 Jan 2021 04:33:41 -0800 Subject: [PATCH 62/62] Add comment explaining test set --- test/rulesets/LinearAlgebra/symmetric.jl | 4 ++++ 1 file changed, 4 insertions(+) diff --git a/test/rulesets/LinearAlgebra/symmetric.jl b/test/rulesets/LinearAlgebra/symmetric.jl index 599d4f778..6d5cbdfce 100644 --- a/test/rulesets/LinearAlgebra/symmetric.jl +++ b/test/rulesets/LinearAlgebra/symmetric.jl @@ -130,6 +130,10 @@ @test @inferred(back(CT())) == (NO_FIELDS, Zero()) end + # when value used to determine phase convention is low, the usual derivatives + # become unstable, causing the rules to compose poorly in a program. + # this test set checks that the rules compose correctly for the function + # f(A) = I, using eigenvectors, where all sensitivities should cancel @testset "phase convention from low value" begin @testset for min_val in [0, eps(), sqrt(eps()), cbrt(eps()), eps()^(1//4)], uplo in (:U, :L)