From 22eddb479062838b7a3e5541cd94a29360eb8617 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Wed, 9 Dec 2020 14:29:17 -0800 Subject: [PATCH] Add eigen and eigvals rules for StridedMatrix (#321) * Add implementation of eigen * Add implementation of eigvals * Add todo notes * Test eigen and eigvals * Choose dimension that is stable * Fix function call * Check that pullbacks are type-stable * Note why we don't check type-stability for rules * Add test for idempotence * Test that sensitivities are real when the primals are * Rearrange tests * Test sensitivities are real when primals are for eigvals * Increment version number * Don't compute eigenvectors if unused * Don't compute full matrix product * Avoid calling Matrix * Overload mutating versions for frule * Test mutating form for frule * Use fewer subscripts * Increment required patch version * Increment version number * Increment version number --- Project.toml | 4 +- src/rulesets/LinearAlgebra/factorization.jl | 132 +++++++++++++++++++ test/rulesets/LinearAlgebra/factorization.jl | 132 +++++++++++++++++++ 3 files changed, 266 insertions(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index cb49e6add..eb7e58431 100644 --- a/Project.toml +++ b/Project.toml @@ -1,6 +1,6 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.37" +version = "0.7.38" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" @@ -14,7 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2" [compat] ChainRulesCore = "0.9.21" -ChainRulesTestUtils = "0.5.1" +ChainRulesTestUtils = "0.5.5" Compat = "3" FiniteDifferences = "0.11.4" Reexport = "0.2" diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 10b0e24a0..00a439f04 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -66,6 +66,138 @@ function svd_rev(USV::SVD, Ū, s̄, V̄) return Ā end +##### +##### `eigen` +##### + +# TODO: +# - support correct differential of phase convention when A is hermitian +# - simplify when A is diagonal +# - support degenerate matrices (see #144) + +function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} + F = eigen!(A; kwargs...) + ΔA isa AbstractZero && return F, ΔA + λ, V = F.values, F.vectors + tmp = V \ ΔA + ∂K = tmp * V + ∂Kdiag = @view ∂K[diagind(∂K)] + ∂λ = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag) + ∂K ./= transpose(λ) .- λ + fill!(∂Kdiag, 0) + ∂V = mul!(tmp, V, ∂K) + _eigen_norm_phase_fwd!(∂V, A, V) + ∂F = Composite{typeof(F)}(values = ∂λ, vectors = ∂V) + return F, ∂F +end + +function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}} + F = eigen(A; kwargs...) + function eigen_pullback(ΔF::Composite{<:Eigen}) + λ, V = F.values, F.vectors + Δλ, ΔV = ΔF.values, ΔF.vectors + if ΔV isa AbstractZero + Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV) + ∂K = Diagonal(Δλ) + ∂A = V' \ ∂K * V' + else + ∂V = copyto!(similar(ΔV), ΔV) + _eigen_norm_phase_rev!(∂V, A, V) + ∂K = V' * ∂V + ∂K ./= λ' .- conj.(λ) + ∂K[diagind(∂K)] .= Δλ + ∂A = mul!(∂K, V' \ ∂K, V') + end + return NO_FIELDS, T <: Real ? real(∂A) : ∂A + end + eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF) + return F, eigen_pullback +end + +# mutate ∂V to account for the (arbitrary but consistent) normalization and phase condition +# applied to the eigenvectors. +# these implementations assume the convention used by eigen in LinearAlgebra (i.e. that of +# LAPACK.geevx!; eigenvectors have unit norm, and the element with the largest absolute +# value is real), but they can be specialized for `A` + +function _eigen_norm_phase_fwd!(∂V, A, V) + @inbounds for i in axes(V, 2) + v, ∂v = @views V[:, i], ∂V[:, i] + # account for unit normalization + ∂c_norm = -real(dot(v, ∂v)) + if eltype(V) <: Real + ∂c = ∂c_norm + else + # account for rotation of largest element to real + k = _findrealmaxabs2(v) + ∂c_phase = -imag(∂v[k]) / real(v[k]) + ∂c = complex(∂c_norm, ∂c_phase) + end + ∂v .+= v .* ∂c + end + return ∂V +end + +function _eigen_norm_phase_rev!(∂V, A, V) + @inbounds for i in axes(V, 2) + v, ∂v = @views V[:, i], ∂V[:, i] + ∂c = dot(v, ∂v) + # account for unit normalization + ∂v .-= real(∂c) .* v + if !(eltype(V) <: Real) + # account for rotation of largest element to real + k = _findrealmaxabs2(v) + @inbounds ∂v[k] -= im * (imag(∂c) / real(v[k])) + end + end + return ∂V +end + +# workaround for findmax not taking a mapped function +function _findrealmaxabs2(x) + amax = abs2(first(x)) + imax = 1 + @inbounds for i in 2:length(x) + xi = x[i] + !isreal(xi) && continue + a = abs2(xi) + a < amax && continue + amax, imax = a, i + end + return imax +end + +##### +##### `eigvals` +##### + +function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat} + ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA + F = eigen!(A; kwargs...) + λ, V = F.values, F.vectors + tmp = V \ ΔA + ∂λ = similar(λ) + # diag(tmp * V) without computing full matrix product + if eltype(∂λ) <: Real + broadcast!((a, b) -> sum(real ∘ prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V)) + else + broadcast!((a, b) -> sum(prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V)) + end + return λ, ∂λ +end + +function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}} + F = eigen(A; kwargs...) + λ = F.values + function eigvals_pullback(Δλ) + V = F.vectors + ∂A = V' \ Diagonal(Δλ) * V' + return NO_FIELDS, T <: Real ? real(∂A) : ∂A + end + eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ) + return λ, eigvals_pullback +end + ##### ##### `cholesky` ##### diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 71e2236fe..416f44169 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -85,6 +85,138 @@ end end end + @testset "eigendecomposition" begin + @testset "eigen/eigen!" begin + # NOTE: eigen!/eigen are not type-stable, so neither are their frule/rrule + + # avoid implementing to_vec(::Eigen) + f(E::Eigen) = (values=E.values, vectors=E.vectors) + + # NOTE: for unstructured matrices, low enough n, and this specific seed, finite + # differences of eigen seems to be stable enough for direct comparison. + # This allows us to directly check differential of normalization/phase + # convention + n = 10 + + @testset "eigen!(::Matrix{$T}) frule" for T in (Float64,ComplexF64) + X = randn(T, n, n) + Ẋ = rand_tangent(X) + F = eigen!(copy(X)) + F_fwd, Ḟ_ad = frule((Zero(), copy(Ẋ)), eigen!, copy(X)) + @test F_fwd == F + @test Ḟ_ad isa Composite{typeof(F)} + Ḟ_fd = jvp(_fdm, f ∘ eigen! ∘ copy, (X, Ẋ)) + @test Ḟ_ad.values ≈ Ḟ_fd.values + @test Ḟ_ad.vectors ≈ Ḟ_fd.vectors + @test frule((Zero(), Zero()), eigen!, copy(X)) == (F, Zero()) + + @testset "tangents are real when outputs are" begin + # hermitian matrices have real eigenvalues and, when real, real eigenvectors + X = Matrix(Hermitian(randn(T, n, n))) + Ẋ = Matrix(Hermitian(rand_tangent(X))) + _, Ḟ = frule((Zero(), Ẋ), eigen!, X) + @test eltype(Ḟ.values) <: Real + T <: Real && @test eltype(Ḟ.vectors) <: Real + end + end + + @testset "eigen(::Matrix{$T}) rrule" for T in (Float64,ComplexF64) + # NOTE: eigen is not type-stable, so neither are is its rrule + X = randn(T, n, n) + F = eigen(X) + V̄ = rand_tangent(F.vectors) + λ̄ = rand_tangent(F.values) + CT = Composite{typeof(F)} + F_rev, back = rrule(eigen, X) + @test F_rev == F + _, X̄_values_ad = @inferred back(CT(values = λ̄)) + @test X̄_values_ad ≈ j′vp(_fdm, x -> eigen(x).values, λ̄, X)[1] + _, X̄_vectors_ad = @inferred back(CT(vectors = V̄)) + @test X̄_vectors_ad ≈ j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1] + F̄ = CT(values = λ̄, vectors = V̄) + s̄elf, X̄_ad = @inferred back(F̄) + @test s̄elf === NO_FIELDS + X̄_fd = j′vp(_fdm, f ∘ eigen, F̄, X)[1] + @test X̄_ad ≈ X̄_fd + @test @inferred(back(Zero())) === (NO_FIELDS, Zero()) + F̄zero = CT(values = Zero(), vectors = Zero()) + @test @inferred(back(F̄zero)) === (NO_FIELDS, Zero()) + + T <: Real && @testset "cotangent is real when input is" begin + X = randn(T, n, n) + Ẋ = rand_tangent(X) + + F = eigen(X) + V̄ = rand_tangent(F.vectors) + λ̄ = rand_tangent(F.values) + F̄ = Composite{typeof(F)}(values = λ̄, vectors = V̄) + X̄ = rrule(eigen, X)[2](F̄)[2] + @test eltype(X̄) <: Real + end + end + + @testset "normalization/phase functions are idempotent" for T in (Float64,ComplexF64) + # this is as much a math check as a code check. because normalization when + # applied repeatedly is idempotent, repeated pushforward/pullback should + # leave the (co)tangent unchanged + X = randn(T, n, n) + Ẋ = rand_tangent(X) + F = eigen(X) + + V̇ = rand_tangent(F.vectors) + V̇proj = ChainRules._eigen_norm_phase_fwd!(copy(V̇), X, F.vectors) + @test !isapprox(V̇, V̇proj) + V̇proj2 = ChainRules._eigen_norm_phase_fwd!(copy(V̇proj), X, F.vectors) + @test V̇proj2 ≈ V̇proj + + V̄ = rand_tangent(F.vectors) + V̄proj = ChainRules._eigen_norm_phase_rev!(copy(V̄), X, F.vectors) + @test !isapprox(V̄, V̄proj) + V̄proj2 = ChainRules._eigen_norm_phase_rev!(copy(V̄proj), X, F.vectors) + @test V̄proj2 ≈ V̄proj + end + end + + @testset "eigvals/eigvals!" begin + # NOTE: eigvals!/eigvals are not type-stable, so neither are their frule/rrule + @testset "eigvals!(::Matrix{$T}) frule" for T in (Float64,ComplexF64) + n = 10 + X = randn(T, n, n) + λ = eigvals!(copy(X)) + Ẋ = rand_tangent(X) + frule_test(eigvals!, (X, Ẋ)) + @test frule((Zero(), Zero()), eigvals!, copy(X)) == (λ, Zero()) + + @testset "tangents are real when outputs are" begin + # hermitian matrices have real eigenvalues + X = Matrix(Hermitian(randn(T, n, n))) + Ẋ = Matrix(Hermitian(rand_tangent(X))) + _, λ̇ = frule((Zero(), Ẋ), eigvals!, X) + @test eltype(λ̇) <: Real + end + end + + @testset "eigvals(::Matrix{$T}) rrule" for T in (Float64,ComplexF64) + n = 10 + X = randn(T, n, n) + X̄ = rand_tangent(X) + λ̄ = rand_tangent(eigvals(X)) + rrule_test(eigvals, λ̄, (X, X̄)) + back = rrule(eigvals, X)[2] + @inferred back(λ̄) + @test @inferred(back(Zero())) === (NO_FIELDS, Zero()) + + T <: Real && @testset "cotangent is real when input is" begin + X = randn(T, n, n) + λ = eigvals(X) + λ̄ = rand_tangent(λ) + X̄ = rrule(eigvals, X)[2](λ̄)[2] + @test eltype(X̄) <: Real + end + end + end + end + # These tests are generally a bit tricky to write because FiniteDifferences doesn't # have fantastic support for this stuff at the minute. @testset "cholesky" begin