Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Symmetric/Hermitian matrix function rules #193

Merged
merged 95 commits into from
Jan 14, 2021
Merged
Show file tree
Hide file tree
Changes from 25 commits
Commits
Show all changes
95 commits
Select commit Hold shift + click to select a range
6d66438
Add symmetric/hermitian eigendecomposition rules
sethaxen May 17, 2020
cac5290
Add utility functions
sethaxen May 17, 2020
bd16f24
Add frules and rrules for sym/herm power series
sethaxen May 17, 2020
bbbabf5
Add int pow rules
sethaxen May 17, 2020
b776223
Add sincos rules
sethaxen May 17, 2020
8d1fdd0
Remove unused function argument
sethaxen May 17, 2020
d50ba1e
Fix and comment _nonzero
sethaxen May 17, 2020
ee3a6fb
Make methods and signatures less ambiguous
sethaxen May 17, 2020
622b5b4
Handle Zero() better
sethaxen May 17, 2020
14b7266
Standardize notation
sethaxen May 17, 2020
57df366
Remove parens
sethaxen May 17, 2020
8c8790a
Update src/rulesets/LinearAlgebra/structured.jl
sethaxen May 17, 2020
0fdd8a5
Merge branch 'symhermpowseries' of https://github.com/sethaxen/ChainR…
sethaxen May 17, 2020
ef15c71
Fix for Julia 1.0
sethaxen May 17, 2020
ad9d36f
Use correct variable and method name
sethaxen May 17, 2020
0078a58
Accumulate in the triangle in the pullback
sethaxen May 17, 2020
5d1685f
Remove comment
sethaxen May 17, 2020
4366b26
Add eigen and eigvals tests
sethaxen May 17, 2020
f6075a5
Remove outdated comment
sethaxen May 18, 2020
90e3225
Clean up and make constraint functions faster
sethaxen May 18, 2020
e35ffef
Make outputs of int pow of Hermitian are Hermitian
sethaxen May 18, 2020
a604e45
Fix typo in comment
sethaxen May 18, 2020
c28304c
Test most power series functions
sethaxen May 18, 2020
e211034
Don't thunk tangents
sethaxen May 21, 2020
3ccfa91
Merge branch 'master' into symhermpowseries
sethaxen May 25, 2020
e4ec19d
Merge branch 'master' into symhermpowseries
sethaxen Jul 1, 2020
6d7c00c
Make type-stable and use optimal threshold
sethaxen Jul 20, 2020
115c201
Merge branch 'master' into symhermpowseries
sethaxen Jul 20, 2020
c877818
Merge branch 'master' into symhermpowseries
sethaxen Nov 20, 2020
f41bfe0
Split out symmetric/hermitian methods/tests
sethaxen Nov 24, 2020
52eef4d
Use correct pullback of hermitrization
sethaxen Nov 24, 2020
af12bed
Stabilize eigenvector computation
sethaxen Nov 25, 2020
46a4ec4
Test composed pullback
sethaxen Nov 25, 2020
e257b11
Remove all eigendecomposition rules
sethaxen Jan 5, 2021
6010c2f
Merge branch 'master' into symhermpowseries
sethaxen Jan 5, 2021
9571d19
Move to utilities section
sethaxen Jan 5, 2021
d7e3762
Move to utilities section
sethaxen Jan 5, 2021
64f96ee
Separate shared code into its own function
sethaxen Jan 5, 2021
109ce2c
Don't thunk
sethaxen Jan 5, 2021
00cccbc
Use correct function name
sethaxen Jan 5, 2021
eb0a7e2
Correctly broadcast
sethaxen Jan 5, 2021
c5a37da
Remove power rules
sethaxen Jan 6, 2021
648e13e
Merge branch 'master' into symhermpowseries
sethaxen Jan 7, 2021
d8b22f1
Rename to matrix functions
sethaxen Jan 8, 2021
0b0cd85
Remove pow tests
sethaxen Jan 8, 2021
f945b13
Expand test suite
sethaxen Jan 8, 2021
2087dac
Remove sincos rules for now
sethaxen Jan 8, 2021
97ec070
Add references and comments
sethaxen Jan 8, 2021
5f1529d
Add _isindomain
sethaxen Jan 8, 2021
537c1f8
Refactor _matfun
sethaxen Jan 8, 2021
65be168
Add _matfun_frechet
sethaxen Jan 8, 2021
2b3e11a
Broadcast instead of indexing
sethaxen Jan 8, 2021
ac0253c
Add comments and use indexing from paper
sethaxen Jan 8, 2021
b7b83f5
Handle Zeros
sethaxen Jan 8, 2021
1c6a889
Contrain differentials according to primals
sethaxen Jan 8, 2021
21340c9
Support all matrix functions
sethaxen Jan 8, 2021
7bf9b7c
Remove unused methods
sethaxen Jan 8, 2021
7d78762
Support Symmetric{Complex}
sethaxen Jan 8, 2021
d1e9947
Add rules for sincos
sethaxen Jan 8, 2021
444a49b
Make atanh rule type-stable
sethaxen Jan 8, 2021
7dcc8a2
Correctly test type-unstable functions
sethaxen Jan 8, 2021
eb52188
Use correct denominator
sethaxen Jan 8, 2021
f78945a
Add tests for almost-singular and low-rank matrices
sethaxen Jan 8, 2021
c8885cd
Remove out-dated comments
sethaxen Jan 8, 2021
b65f552
Test alternate differentials
sethaxen Jan 8, 2021
e6106f3
Don't use only
sethaxen Jan 8, 2021
68b9597
Remove _hermitrizeback!
sethaxen Jan 8, 2021
0778d7a
Don't use hasproperty, not in old Julia versions
sethaxen Jan 8, 2021
77cba6d
Reduce allocations
sethaxen Jan 8, 2021
fc34770
simplify section name
sethaxen Jan 8, 2021
62afe5e
Simplify line
sethaxen Jan 8, 2021
69bcbe8
Handle mixture of non-Zero and Zero
sethaxen Jan 8, 2021
6f6d38b
Don't loop over unused functions
sethaxen Jan 8, 2021
2f6cbeb
Test against component frules instead of fd
sethaxen Jan 8, 2021
673a258
Test that rules produce same uplo as primal
sethaxen Jan 8, 2021
357ecb8
Apply suggestions from code review
sethaxen Jan 12, 2021
bf2191c
Reuse variable name
sethaxen Jan 12, 2021
f1cba00
Use bang bang convention for maybe-in-place
sethaxen Jan 12, 2021
05e8363
Don't assume the wrapped matrix is mutable
sethaxen Jan 12, 2021
8a60771
Replace hermitrize!
sethaxen Jan 12, 2021
8c77687
Use diagind
sethaxen Jan 12, 2021
3ce3d8a
Remove handling of Zero differential
sethaxen Jan 12, 2021
9b40c09
Unify symbols
sethaxen Jan 12, 2021
e9aef74
Use hasproperty
sethaxen Jan 12, 2021
e338822
Load hasproperty from Compat
sethaxen Jan 12, 2021
17686f2
Replace refs with one to Higham
sethaxen Jan 12, 2021
359a1dc
Add docstrings
sethaxen Jan 12, 2021
32a9cca
Update src/rulesets/LinearAlgebra/symmetric.jl
sethaxen Jan 12, 2021
d81ee72
Merge branch 'master' into symhermpowseries
sethaxen Jan 13, 2021
73d5d01
Merge branch 'master' into symhermpowseries
sethaxen Jan 13, 2021
fd60d44
Increment version number
sethaxen Jan 13, 2021
867ea12
Use utility function
sethaxen Jan 13, 2021
b7f6c40
Stabilize jvp Jacobian dimensions
sethaxen Jan 14, 2021
f956f62
Don't use non-exported function
sethaxen Jan 14, 2021
d3ff01a
Bump required ChainRulesCore
sethaxen Jan 14, 2021
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
246 changes: 246 additions & 0 deletions src/rulesets/LinearAlgebra/structured.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,252 @@ end
_symmetric_back(ΔΩ) = UpperTriangular(ΔΩ) + LowerTriangular(ΔΩ)' - Diagonal(ΔΩ)
_symmetric_back(ΔΩ::Union{Diagonal,UpperTriangular}) = ΔΩ

#####
##### `Symmetric{<:Real}`/`Hermitian` eigendecomposition
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
#####

function frule((_, ΔA), ::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm)
F = eigen(A)
λ, U = F
∂Λ = U' * ΔA * U
∂λ = real(diag(∂Λ)) # if ΔA is Hermitian, so is ∂Λ, so its diagonal is real
# K is skew-hermitian with zero diag
K = ∂Λ ./ _nonzero.(λ' .- λ)
_setdiag!(K, Zero())
∂U = U * K
∂F = Composite{typeof(F)}(values = ∂λ, vectors = ∂U)
return F, ∂F
end

function rrule(::typeof(eigen), A::LinearAlgebra.RealHermSymComplexHerm)
F = eigen(A)
function eigen_pullback(ΔF)
sethaxen marked this conversation as resolved.
Show resolved Hide resolved
∂A = Thunk() do
λ, U = F
∂λ, ∂U = ΔF.values, ΔF.vectors
if ∂U isa AbstractZero
U′∂AU = Diagonal(∂λ)
else
K = U' * ∂U
# unstable for degenerate matrices
U′∂AU = K ./ _nonzero.(λ' .- λ)
_setdiag!(U′∂AU, ∂λ)
end
return _symhermback!(U * U′∂AU * U', A)
end
return NO_FIELDS, ∂A
end
return F, eigen_pullback
end

function frule((_, ΔA), ::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm)
λ, U = eigen(A)
∂Λ = U' * ΔA * U
∂λ = real(diag(∂Λ)) # if ΔA is Hermitian, so is ∂Λ, so its diagonal is real
return λ, ∂λ
end

function rrule(::typeof(eigvals), A::LinearAlgebra.RealHermSymComplexHerm)
F, back = rrule(eigen, A)
λ, U = F
function eigvals_pullback(Δλ)
∂A = Thunk() do
∂F = Composite{typeof(F)}(values = Δλ)
_, ∂A = back(∂F)
return unthunk(∂A)
end
return NO_FIELDS, ∂A
end
return λ, eigvals_pullback
end

# if |x| < eps(), return a small number with its sign, where zero has a positive sign.
_nonzero(x) = ifelse(signbit(x), min(x, -eps(eltype(x))), max(x, eps(eltype(x))))

function _setdiag!(A, d)
for i in axes(A, 1)
@inbounds A[i, i] = d isa AbstractZero ? 0 : d[i]
end
return A
end

_pureimag(x) = x - real(x)

function _realifydiag!(A)
for i in axes(A, 1)
@inbounds A[i, i] = real(A[i, i])
end
return A
end
_realifydiag!(A::AbstractMatrix{<:Real}) = A

_realifydiag(A) = A .- Diagonal(_pureimag.(diag(A)))
_realifydiag(A::AbstractMatrix{<:Real}) = A

_symherm(A::AbstractMatrix{<:Real}, uplo = :U) = Symmetric(A, uplo)
_symherm(A::AbstractMatrix{<:Complex}, uplo = :U) = Hermitian(A, uplo)

_symhermtype(A::Symmetric) = Symmetric
_symhermtype(A::Hermitian) = Hermitian

function _symhermlike!(A, S::LinearAlgebra.RealHermSymComplexHerm)
_realifydiag!(A)
return typeof(S)(A, S.uplo)
end

function _symhermfwd!(A, uplo = :U)
_realifydiag!(A)
return _symherm(A, uplo)
end

# pullback of hermitrization
function _symhermback!(∂A, A)
@inbounds for i in axes(∂A, 1)
for j in 1:(i - 1)
if A.uplo === 'U'
∂A[j, i] += ∂A[i, j]
∂A[i, j] = 0
else
∂A[i, j] += ∂A[j, i]
∂A[j, i] = 0
end
end
if eltype(∂A) <: Complex
∂A[i, i] = real(∂A[i, i])
end
end
return typeof(A)(∂A, A.uplo)
end

#####
##### `Symmetric{<:Real}`/`Hermitian` power series functions
#####

# Currently only defined for series functions whose codomain is ℝ
# These are type-stable and closed under `func`

# The efficient way to do this is probably to AD Base.power_by_squaring
function frule((_, ΔA, _), ::typeof(^), A::LinearAlgebra.RealHermSymComplexHerm, p::Integer)
λ, U = eigen(A)
λᵖ = λ .^ p
Y = U * Diagonal(λᵖ) * U'
_realifydiag!(Y)
Y = _symhermtype(A)(Y, :U)
dλᵖ_dλ = p .* λ .^ (p - 1)
∂Λ = U' * ΔA * U
U′∂YU = _muldiffquotmat(λ, λᵖ, dλᵖ_dλ, ∂Λ)
∂Y = _symhermlike!(U * U′∂YU * U', Y)
return Y, ∂Y
end

function rrule(::typeof(^), A::LinearAlgebra.RealHermSymComplexHerm, p::Integer)
λ, U = eigen(A)
λᵖ = λ .^ p
Y = U * Diagonal(λᵖ) * U'
_realifydiag!(Y)
Y = _symhermtype(A)(Y, :U)
function pow_pullback(ΔY)
∂A = Thunk() do
dλᵖ_dλ = p .* λ .^ (p - 1)
∂Λᵖ = U' * _realifydiag(ΔY) * U
U′∂AU = _muldiffquotmat(λ, λᵖ, dλᵖ_dλ, ∂Λᵖ)
return _symhermback!(U * U′∂AU * U', A)
end
return NO_FIELDS, ∂A, DoesNotExist()
end
return Y, pow_pullback
end

# TODO: support log, sqrt, acos, asin, and non-int pow, which are type-unstable
for func in (:exp, :cos, :sin, :tan, :cosh, :sinh, :tanh, :atan, :asinh, :atanh)
@eval begin
function frule((_, ΔA), ::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
df = λi -> frule((Zero(), One()), $func, λi)
λ, U = eigen(A)
fλ_df_dλ = df.(λ)
fλ = first.(fλ_df_dλ)
Y = _symhermfwd!(U * Diagonal(fλ) * U')
df_dλ = last.(unthunk.(fλ_df_dλ))
∂Λ = U' * ΔA * U
U′∂YU = _muldiffquotmat(λ, fλ, df_dλ, ∂Λ)
∂Y = _symhermlike!(U * U′∂YU * U', Y)
return Y, ∂Y
end

function rrule(::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
df = λi -> frule((Zero(), One()), $func, λi)
λ, U = eigen(A)
fλ_df_dλ = df.(λ)
fλ = first.(fλ_df_dλ)
Y = _symhermfwd!(U * Diagonal(fλ) * U')
function $(Symbol("$(func)_pullback"))(ΔY)
∂A = Thunk() do
df_dλ = unthunk.(last.(fλ_df_dλ))
∂fΛ = U' * _realifydiag(ΔY) * U
U′∂AU = _muldiffquotmat(λ, fλ, df_dλ, ∂fΛ)
return _symhermback!(U * U′∂AU * U', A)
end
return NO_FIELDS, ∂A
end
return Y, $(Symbol("$(func)_pullback"))
end
end
end

function frule((_, ΔA), ::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
λ, U = eigen(A)
sinλ, cosλ = sin.(λ), cos.(λ)
sinA = _symhermfwd!(U * Diagonal(sinλ) * U')
cosA = _symhermfwd!(U * Diagonal(cosλ) * U')
sincosA = (sinA, cosA)
∂Λ = U' * ΔA * U
U′∂sinAU = _muldiffquotmat(λ, sinλ, cosλ, ∂Λ)
∂sinA = _symhermlike!(U * U′∂sinAU * U', sinA)
U′∂cosAU = _muldiffquotmat(λ, cosλ, -sinλ, ∂Λ)
∂cosA = _symhermlike!(U * U′∂cosAU * U', cosA)
∂sincosA = Composite{typeof(sincosA)}(∂sinA, ∂cosA)
return sincosA, ∂sincosA
end

function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
λ, U = eigen(A)
sinλ, cosλ = sin.(λ), cos.(λ)
sinA = _symhermfwd!(U * Diagonal(sinλ) * U')
cosA = _symhermfwd!(U * Diagonal(cosλ) * U')
sincosA = (sinA, cosA)
function sincos_pullback(ΔsincosA)
∂A = Thunk() do
ΔsinA, ΔcosA = ΔsincosA
∂sinΛ, ∂cosΛ = U' * _realifydiag(ΔsinA) * U, U' * _realifydiag(ΔcosA) * U
inds = eachindex(λ)
U′∂AU = @inbounds begin
_diffquot.(inds, inds', Ref(λ), Ref(sinλ), Ref(cosλ)) .* ∂sinΛ .+
_diffquot.(inds, inds', Ref(λ), Ref(cosλ), Ref(-sinλ)) .* ∂cosΛ
end
return _symhermback!(U * U′∂AU * U', A)
end
return NO_FIELDS, ∂A
end
return sincosA, sincos_pullback
end

# difference quotient, i.e. Pᵢⱼ = (f(λᵢ) - f(λⱼ)) / (λᵢ - λⱼ), with f'(λᵢ) when i==j
Base.@propagate_inbounds function _diffquot(i, j, λ, fλ, df_dλ)
i == j && return df_dλ[i]
Δλ = λ[i] - λ[j]
T = real(eltype(λ))
# Handle degenerate eigenvalues by taylor expanding Δfλ / Δλ as Δλ → 0
abs2(Δλ) < eps(T) && return (df_dλ[i] + df_dλ[j]) / 2
Δfλ = fλ[i] - fλ[j]
return Δfλ / Δλ
end

# multiply Δ by the matrix of difference quotients P
function _muldiffquotmat(λ, fλ, df_dλ, Δ)
inds = eachindex(λ)
return @inbounds _diffquot.(inds, inds', Ref(λ), Ref(fλ), Ref(df_dλ)) .* Δ
end

#####
##### `Adjoint`
#####
Expand Down
Loading