Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add eigen and eigvals rules for StridedMatrix #321

Merged
merged 24 commits into from
Dec 9, 2020
Merged
Show file tree
Hide file tree
Changes from 21 commits
Commits
Show all changes
24 commits
Select commit Hold shift + click to select a range
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ Statistics = "10745b16-79ce-11e8-11f9-7d13ad32a3b2"

[compat]
ChainRulesCore = "0.9.21"
ChainRulesTestUtils = "0.5.1"
ChainRulesTestUtils = "0.5.5"
Compat = "3"
FiniteDifferences = "0.11.4"
Reexport = "0.2"
Expand Down
132 changes: 132 additions & 0 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,138 @@ function svd_rev(USV::SVD, Ū, s̄, V̄)
return Ā
end

#####
##### `eigen`
#####

# TODO:
# - support correct differential of phase convention when A is hermitian
# - simplify when A is diagonal
# - support degenerate matrices (see #144)

function frule((_, ΔA), ::typeof(eigen!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
F = eigen!(A; kwargs...)
ΔA isa AbstractZero && return F, ΔA
λ, V = F.values, F.vectors
tmp = V \ ΔA
∂K = tmp * V
∂Kdiag = @view ∂K[diagind(∂K)]
∂λ = eltype(λ) <: Real ? real.(∂Kdiag) : copy(∂Kdiag)
∂K ./= transpose(λ) .- λ
fill!(∂Kdiag, 0)
∂V = mul!(tmp, V, ∂K)
_eigen_norm_phase_fwd!(∂V, A, V)
∂F = Composite{typeof(F)}(values = ∂λ, vectors = ∂V)
return F, ∂F
end

function rrule(::typeof(eigen), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}}
F = eigen(A; kwargs...)
function eigen_pullback(ΔF::Composite{<:Eigen})
λ, V = F.values, F.vectors
Δλ, ΔV = ΔF.values, ΔF.vectors
if ΔV isa AbstractZero
Δλ isa AbstractZero && return (NO_FIELDS, Δλ + ΔV)
∂K = Diagonal(Δλ)
∂A = V' \ ∂K * V'
else
∂V = copyto!(similar(ΔV), ΔV)
_eigen_norm_phase_rev!(∂V, A, V)
∂K = V' * ∂V
∂K ./= λ' .- conj.(λ)
∂K[diagind(∂K)] .= Δλ
∂A = mul!(∂K, V' \ ∂K, V')
end
return NO_FIELDS, T <: Real ? real(∂A) : ∂A
end
eigen_pullback(ΔF::AbstractZero) = (NO_FIELDS, ΔF)
return F, eigen_pullback
end

# mutate ∂V to account for the (arbitrary but consistent) normalization and phase condition
# applied to the eigenvectors.
# these implementations assume the convention used by eigen in LinearAlgebra (i.e. that of
# LAPACK.geevx!; eigenvectors have unit norm, and the element with the largest absolute
# value is real), but they can be specialized for `A`

function _eigen_norm_phase_fwd!(∂V, A, V)
@inbounds for i in axes(V, 2)
v, ∂v = @views V[:, i], ∂V[:, i]
# account for unit normalization
∂c_norm = -real(dot(v, ∂v))
if eltype(V) <: Real
∂c = ∂c_norm
else
# account for rotation of largest element to real
k = _findrealmaxabs2(v)
∂c_phase = -imag(∂v[k]) / real(v[k])
∂c = complex(∂c_norm, ∂c_phase)
end
∂v .+= v .* ∂c
end
return ∂V
end

function _eigen_norm_phase_rev!(∂V, A, V)
@inbounds for i in axes(V, 2)
v, ∂v = @views V[:, i], ∂V[:, i]
∂c = dot(v, ∂v)
# account for unit normalization
∂v .-= real(∂c) .* v
if !(eltype(V) <: Real)
# account for rotation of largest element to real
k = _findrealmaxabs2(v)
@inbounds ∂v[k] -= im * (imag(∂c) / real(v[k]))
end
end
return ∂V
end

# workaround for findmax not taking a mapped function
function _findrealmaxabs2(x)
amax = abs2(first(x))
imax = 1
@inbounds for i in 2:length(x)
xi = x[i]
!isreal(xi) && continue
a = abs2(xi)
a < amax && continue
amax, imax = a, i
end
return imax
end

#####
##### `eigvals`
#####

function frule((_, ΔA), ::typeof(eigvals!), A::StridedMatrix{T}; kwargs...) where {T<:BlasFloat}
ΔA isa AbstractZero && return eigvals!(A; kwargs...), ΔA
F = eigen!(A; kwargs...)
λ, V = F.values, F.vectors
tmp = V \ ΔA
∂λ = similar(λ)
# diag(tmp * V) without computing full matrix product
if eltype(∂λ) <: Real
broadcast!((a, b) -> sum(real ∘ prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
else
broadcast!((a, b) -> sum(prod, zip(a, b)), ∂λ, eachrow(tmp), eachcol(V))
end
return λ, ∂λ
end

function rrule(::typeof(eigvals), A::StridedMatrix{T}; kwargs...) where {T<:Union{Real,Complex}}
F = eigen(A; kwargs...)
λ = F.values
function eigvals_pullback(Δλ)
V = F.vectors
∂A = V' \ Diagonal(Δλ) * V'
return NO_FIELDS, T <: Real ? real(∂A) : ∂A
end
eigvals_pullback(Δλ::AbstractZero) = (NO_FIELDS, Δλ)
return λ, eigvals_pullback
end

#####
##### `cholesky`
#####
Expand Down
131 changes: 131 additions & 0 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,137 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
@test ChainRules._eyesubx!(copy(X)) ≈ I - X
end
end
@testset "eigendecomposition" begin
@testset "eigen/eigen!" begin
# NOTE: eigen!/eigen are not type-stable, so neither are their frule/rrule

# avoid implementing to_vec(::Eigen)
f(E::Eigen) = (values=E.values, vectors=E.vectors)

# NOTE: for unstructured matrices, low enough n, and this specific seed, finite
# differences of eigen seems to be stable enough for direct comparison.
# This allows us to directly check differential of normalization/phase
# convention
n = 10

@testset "eigen!(::Matrix{$T}) frule" for T in (Float64,ComplexF64)
X = randn(T, n, n)
Ẋ = rand_tangent(X)
F = eigen!(copy(X))
F_fwd, Ḟ_ad = frule((Zero(), copy(Ẋ)), eigen!, copy(X))
@test F_fwd == F
@test Ḟ_ad isa Composite{typeof(F)}
Ḟ_fd = jvp(_fdm, f ∘ eigen! ∘ copy, (X, Ẋ))
@test Ḟ_ad.values ≈ Ḟ_fd.values
@test Ḟ_ad.vectors ≈ Ḟ_fd.vectors
@test frule((Zero(), Zero()), eigen!, copy(X)) == (F, Zero())

@testset "tangents are real when outputs are" begin
# hermitian matrices have real eigenvalues and, when real, real eigenvectors
X = Matrix(Hermitian(randn(T, n, n)))
Ẋ = Matrix(Hermitian(rand_tangent(X)))
_, Ḟ = frule((Zero(), Ẋ), eigen!, X)
@test eltype(Ḟ.values) <: Real
T <: Real && @test eltype(Ḟ.vectors) <: Real
end
end

@testset "eigen(::Matrix{$T}) rrule" for T in (Float64,ComplexF64)
# NOTE: eigen is not type-stable, so neither are is its rrule
X = randn(T, n, n)
F = eigen(X)
V̄ = rand_tangent(F.vectors)
λ̄ = rand_tangent(F.values)
CT = Composite{typeof(F)}
F_rev, back = rrule(eigen, X)
@test F_rev == F
_, X̄_values_ad = @inferred back(CT(values = λ̄))
@test X̄_values_ad ≈ j′vp(_fdm, x -> eigen(x).values, λ̄, X)[1]
_, X̄_vectors_ad = @inferred back(CT(vectors = V̄))
@test X̄_vectors_ad ≈ j′vp(_fdm, x -> eigen(x).vectors, V̄, X)[1]
F̄ = CT(values = λ̄, vectors = V̄)
s̄elf, X̄_ad = @inferred back(F̄)
@test s̄elf === NO_FIELDS
X̄_fd = j′vp(_fdm, f ∘ eigen, F̄, X)[1]
@test X̄_ad ≈ X̄_fd
@test @inferred(back(Zero())) === (NO_FIELDS, Zero())
F̄zero = CT(values = Zero(), vectors = Zero())
@test @inferred(back(F̄zero)) === (NO_FIELDS, Zero())

T <: Real && @testset "cotangent is real when input is" begin
X = randn(T, n, n)
Ẋ = rand_tangent(X)

F = eigen(X)
V̄ = rand_tangent(F.vectors)
λ̄ = rand_tangent(F.values)
F̄ = Composite{typeof(F)}(values = λ̄, vectors = V̄)
X̄ = rrule(eigen, X)[2](F̄)[2]
@test eltype(X̄) <: Real
end
end

@testset "normalization/phase functions are idempotent" for T in (Float64,ComplexF64)
# this is as much a math check as a code check. because normalization when
# applied repeatedly is idempotent, repeated pushforward/pullback should
# leave the (co)tangent unchanged
X = randn(T, n, n)
Ẋ = rand_tangent(X)
F = eigen(X)

V̇ = rand_tangent(F.vectors)
V̇proj = ChainRules._eigen_norm_phase_fwd!(copy(V̇), X, F.vectors)
@test !isapprox(V̇, V̇proj)
V̇proj2 = ChainRules._eigen_norm_phase_fwd!(copy(V̇proj), X, F.vectors)
@test V̇proj2 ≈ V̇proj

V̄ = rand_tangent(F.vectors)
V̄proj = ChainRules._eigen_norm_phase_rev!(copy(V̄), X, F.vectors)
@test !isapprox(V̄, V̄proj)
V̄proj2 = ChainRules._eigen_norm_phase_rev!(copy(V̄proj), X, F.vectors)
@test V̄proj2 ≈ V̄proj
end
end

@testset "eigvals/eigvals!" begin
# NOTE: eigvals!/eigvals are not type-stable, so neither are their frule/rrule
@testset "eigvals!(::Matrix{$T}) frule" for T in (Float64,ComplexF64)
n = 10
X = randn(T, n, n)
λ = eigvals!(copy(X))
Ẋ = rand_tangent(X)
frule_test(eigvals!, (X, Ẋ))
@test frule((Zero(), Zero()), eigvals!, copy(X)) == (λ, Zero())

@testset "tangents are real when outputs are" begin
# hermitian matrices have real eigenvalues
X = Matrix(Hermitian(randn(T, n, n)))
Ẋ = Matrix(Hermitian(rand_tangent(X)))
_, λ̇ = frule((Zero(), Ẋ), eigvals!, X)
@test eltype(λ̇) <: Real
end
end

@testset "eigvals(::Matrix{$T}) rrule" for T in (Float64,ComplexF64)
n = 10
X = randn(T, n, n)
X̄ = rand_tangent(X)
λ̄ = rand_tangent(eigvals(X))
rrule_test(eigvals, λ̄, (X, X̄))
back = rrule(eigvals, X)[2]
@inferred back(λ̄)
@test @inferred(back(Zero())) === (NO_FIELDS, Zero())

T <: Real && @testset "cotangent is real when input is" begin
X = randn(T, n, n)
λ = eigvals(X)
λ̄ = rand_tangent(λ)
X̄ = rrule(eigvals, X)[2](λ̄)[2]
@test eltype(X̄) <: Real
end
end
end
end
@testset "cholesky" begin
@testset "the thing" begin
X = generate_well_conditioned_matrix(10)
Expand Down