Skip to content

Commit

Permalink
Add eigen and eigvals rules for StridedMatrix (#321)
Browse files Browse the repository at this point in the history
* Add implementation of eigen

* Add implementation of eigvals

* Add todo notes

* Test eigen and eigvals

* Choose dimension that is stable

* Fix function call

* Check that pullbacks are type-stable

* Note why we don't check type-stability for rules

* Add test for idempotence

* Test that sensitivities are real when the primals are

* Rearrange tests

* Test sensitivities are real when primals are for eigvals

* Increment version number

* Don't compute eigenvectors if unused

* Don't compute full matrix product

* Avoid calling Matrix

* Overload mutating versions for frule

* Test mutating form for frule

* Use fewer subscripts

* Increment required patch version

* Increment version number

* Increment version number
  • Loading branch information
sethaxen authored Dec 9, 2020
1 parent 80443b1 commit 22eddb4
Show file tree
Hide file tree
Showing 3 changed files with 266 additions and 2 deletions.
4 changes: 2 additions & 2 deletions Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.37"
version = "0.7.38"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand All @@ -14,7 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9.21"
ChainRulesTestUtils = "0.5.1"
ChainRulesTestUtils = "0.5.5"
Compat = "3"
FiniteDifferences = "0.11.4"
Reexport = "0.2"
Expand Down
132 changes: 132 additions & 0 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,138 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
return Ā
end

#####
##### `eigen`
#####

# TODO:
# - support correct differential of phase convention when A is hermitian
# - simplify when A is diagonal
# - support degenerate matrices (see #144)

function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
F = eigen!(A; kwargs...)
ΔA isa AbstractZero && return F, ΔA
λ, V = F.values, F.vectors
tmp = V \ ΔA
∂K = tmp * V
∂Kdiag = @view ∂K[diagind(∂K)]
∂λ = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag)
∂K ./= transpose(λ) .- λ
fill!(∂Kdiag, 0)
∂V = mul!(tmp, V, ∂K)
_eigen_norm_phase_fwd!(∂V, A, V)
∂F = Composite{typeof(F)}(values = ∂λ, vectors = ∂V)
return F, ∂F
end

function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}}
F = eigen(A; kwargs...)
function eigen_pullback(ΔF::Composite{<:Eigen})
λ, V = F.values, F.vectors
Δλ, ΔV = ΔF.values, ΔF.vectors
if ΔV isa AbstractZero
Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV)
∂K = Diagonal(Δλ)
∂A = V' \ ∂K * V'
else
∂V = copyto!(similar(ΔV), ΔV)
_eigen_norm_phase_rev!(∂V, A, V)
∂K = V' * ∂V
∂K ./= λ' .- conj.(λ)
∂K[diagind(∂K)] .= Δλ
∂A = mul!(∂K, V' \ ∂K, V')
end
return NO_FIELDS, T <: Real ? real(∂A) : ∂A
end
eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF)
return F, eigen_pullback
end

# mutate ∂V to account for the (arbitrary but consistent) normalization and phase condition
# applied to the eigenvectors.
# these implementations assume the convention used by eigen in LinearAlgebra (i.e. that of
# LAPACK.geevx!; eigenvectors have unit norm, and the element with the largest absolute
# value is real), but they can be specialized for `A`

function _eigen_norm_phase_fwd!(∂V, A, V)
@inbounds for i in axes(V, 2)
v, ∂v = @views V[:, i], ∂V[:, i]
# account for unit normalization
∂c_norm = -real(dot(v, ∂v))
if eltype(V) <: Real
∂c = ∂c_norm
else
# account for rotation of largest element to real
k = _findrealmaxabs2(v)
∂c_phase = -imag(∂v[k]) / real(v[k])
∂c = complex(∂c_norm, ∂c_phase)
end
∂v .+= v .* ∂c
end
return ∂V
end

function _eigen_norm_phase_rev!(∂V, A, V)
@inbounds for i in axes(V, 2)
v, ∂v = @views V[:, i], ∂V[:, i]
∂c = dot(v, ∂v)
# account for unit normalization
∂v .-= real(∂c) .* v
if !(eltype(V) <: Real)
# account for rotation of largest element to real
k = _findrealmaxabs2(v)
@inbounds ∂v[k] -= im * (imag(∂c) / real(v[k]))
end
end
return ∂V
end

# workaround for findmax not taking a mapped function
function _findrealmaxabs2(x)
amax = abs2(first(x))
imax = 1
@inbounds for i in 2:length(x)
xi = x[i]
!isreal(xi) && continue
a = abs2(xi)
a < amax && continue
amax, imax = a, i
end
return imax
end

#####
##### `eigvals`
#####

function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA
F = eigen!(A; kwargs...)
λ, V = F.values, F.vectors
tmp = V \ ΔA
∂λ = similar(λ)
# diag(tmp * V) without computing full matrix product
if eltype(∂λ) <: Real
broadcast!((a, b) -> sum(real prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
else
broadcast!((a, b) -> sum(prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
end
return λ, ∂λ
end

function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}}
F = eigen(A; kwargs...)
λ = F.values
function eigvals_pullback(Δλ)
V = F.vectors
∂A = V' \ Diagonal(Δλ) * V'
return NO_FIELDS, T <: Real ? real(∂A) : ∂A
end
eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ)
return λ, eigvals_pullback
end

#####
##### `cholesky`
#####
Expand Down
132 changes: 132 additions & 0 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,138 @@ end
end
end

@testset "eigendecomposition" begin
@testset "eigen/eigen!" begin
# NOTE: eigen!/eigen are not type-stable, so neither are their frule/rrule

# avoid implementing to_vec(::Eigen)
f(E::Eigen) = (values=E.values, vectors=E.vectors)

# NOTE: for unstructured matrices, low enough n, and this specific seed, finite
# differences of eigen seems to be stable enough for direct comparison.
# This allows us to directly check differential of normalization/phase
# convention
n = 10

@testset "eigen!(::Matrix{$T}) frule" for T in (Float64,ComplexF64)
X = randn(T, n, n)
= rand_tangent(X)
F = eigen!(copy(X))
F_fwd, Ḟ_ad = frule((Zero(), copy(Ẋ)), eigen!, copy(X))
@test F_fwd == F
@test Ḟ_ad isa Composite{typeof(F)}
Ḟ_fd = jvp(_fdm, f eigen! copy, (X, Ẋ))
@test Ḟ_ad.values Ḟ_fd.values
@test Ḟ_ad.vectors Ḟ_fd.vectors
@test frule((Zero(), Zero()), eigen!, copy(X)) == (F, Zero())

@testset "tangents are real when outputs are" begin
# hermitian matrices have real eigenvalues and, when real, real eigenvectors
X = Matrix(Hermitian(randn(T, n, n)))
= Matrix(Hermitian(rand_tangent(X)))
_, Ḟ = frule((Zero(), Ẋ), eigen!, X)
@test eltype(Ḟ.values) <: Real
T <: Real && @test eltype(Ḟ.vectors) <: Real
end
end

@testset "eigen(::Matrix{$T}) rrule" for T in (Float64,ComplexF64)
# NOTE: eigen is not type-stable, so neither are is its rrule
X = randn(T, n, n)
F = eigen(X)
= rand_tangent(F.vectors)
λ̄ = rand_tangent(F.values)
CT = Composite{typeof(F)}
F_rev, back = rrule(eigen, X)
@test F_rev == F
_, X̄_values_ad = @inferred back(CT(values = λ̄))
@test X̄_values_ad j′vp(_fdm, x -> eigen(x).values, λ̄, X)[1]
_, X̄_vectors_ad = @inferred back(CT(vectors = V̄))
@test X̄_vectors_ad j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1]
= CT(values = λ̄, vectors = V̄)
s̄elf, X̄_ad = @inferred back(F̄)
@test s̄elf === NO_FIELDS
X̄_fd = j′vp(_fdm, f eigen, F̄, X)[1]
@test X̄_ad X̄_fd
@test @inferred(back(Zero())) === (NO_FIELDS, Zero())
F̄zero = CT(values = Zero(), vectors = Zero())
@test @inferred(back(F̄zero)) === (NO_FIELDS, Zero())

T <: Real && @testset "cotangent is real when input is" begin
X = randn(T, n, n)
= rand_tangent(X)

F = eigen(X)
= rand_tangent(F.vectors)
λ̄ = rand_tangent(F.values)
= Composite{typeof(F)}(values = λ̄, vectors = V̄)
= rrule(eigen, X)[2](F̄)[2]
@test eltype(X̄) <: Real
end
end

@testset "normalization/phase functions are idempotent" for T in (Float64,ComplexF64)
# this is as much a math check as a code check. because normalization when
# applied repeatedly is idempotent, repeated pushforward/pullback should
# leave the (co)tangent unchanged
X = randn(T, n, n)
= rand_tangent(X)
F = eigen(X)

= rand_tangent(F.vectors)
V̇proj = ChainRules._eigen_norm_phase_fwd!(copy(V̇), X, F.vectors)
@test !isapprox(V̇, V̇proj)
V̇proj2 = ChainRules._eigen_norm_phase_fwd!(copy(V̇proj), X, F.vectors)
@test V̇proj2 V̇proj

= rand_tangent(F.vectors)
V̄proj = ChainRules._eigen_norm_phase_rev!(copy(V̄), X, F.vectors)
@test !isapprox(V̄, V̄proj)
V̄proj2 = ChainRules._eigen_norm_phase_rev!(copy(V̄proj), X, F.vectors)
@test V̄proj2 V̄proj
end
end

@testset "eigvals/eigvals!" begin
# NOTE: eigvals!/eigvals are not type-stable, so neither are their frule/rrule
@testset "eigvals!(::Matrix{$T}) frule" for T in (Float64,ComplexF64)
n = 10
X = randn(T, n, n)
λ = eigvals!(copy(X))
= rand_tangent(X)
frule_test(eigvals!, (X, Ẋ))
@test frule((Zero(), Zero()), eigvals!, copy(X)) == (λ, Zero())

@testset "tangents are real when outputs are" begin
# hermitian matrices have real eigenvalues
X = Matrix(Hermitian(randn(T, n, n)))
= Matrix(Hermitian(rand_tangent(X)))
_, λ̇ = frule((Zero(), Ẋ), eigvals!, X)
@test eltype(λ̇) <: Real
end
end

@testset "eigvals(::Matrix{$T}) rrule" for T in (Float64,ComplexF64)
n = 10
X = randn(T, n, n)
= rand_tangent(X)
λ̄ = rand_tangent(eigvals(X))
rrule_test(eigvals, λ̄, (X, X̄))
back = rrule(eigvals, X)[2]
@inferred back(λ̄)
@test @inferred(back(Zero())) === (NO_FIELDS, Zero())

T <: Real && @testset "cotangent is real when input is" begin
X = randn(T, n, n)
λ = eigvals(X)
λ̄ = rand_tangent(λ)
= rrule(eigvals, X)[2](λ̄)[2]
@test eltype(X̄) <: Real
end
end
end
end

# These tests are generally a bit tricky to write because FiniteDifferences doesn't
# have fantastic support for this stuff at the minute.
@testset "cholesky" begin
Expand Down

4 comments on commit 22eddb4

@sethaxen
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request created: JuliaRegistries/General/26146

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.38 -m "<description of version>" 22eddb479062838b7a3e5541cd94a29360eb8617
git push origin v0.7.38

Also, note the warning: Version 0.7.38 skips over 0.7.37
This can be safely ignored. However, if you want to fix this you can do so. Call register() again after making the fix. This will update the Pull request.

@willtebbutt
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@JuliaRegistrator register()

@JuliaRegistrator
Copy link

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Registration pull request updated: JuliaRegistries/General/26146

After the above pull request is merged, it is recommended that a tag is created on this repository for the registered package version.

This will be done automatically if the Julia TagBot GitHub Action is installed, or can be done manually through the github interface, or via:

git tag -a v0.7.38 -m "<description of version>" 22eddb479062838b7a3e5541cd94a29360eb8617
git push origin v0.7.38

Please sign in to comment.