Symmetric/Hermitian matrix function rules #193

Merged: 95 commits, merged Jan 14, 2021
Changes shown below are from 76 of the 95 commits.

Commits
6d66438
Add symmetric/hermitian eigendecomposition rules
sethaxen May 17, 2020
cac5290
Add utility functions
sethaxen May 17, 2020
bd16f24
Add frules and rrules for sym/herm power series
sethaxen May 17, 2020
bbbabf5
Add int pow rules
sethaxen May 17, 2020
b776223
Add sincos rules
sethaxen May 17, 2020
8d1fdd0
Remove unused function argument
sethaxen May 17, 2020
d50ba1e
Fix and comment _nonzero
sethaxen May 17, 2020
ee3a6fb
Make methods and signatures less ambiguous
sethaxen May 17, 2020
622b5b4
Handle Zero() better
sethaxen May 17, 2020
14b7266
Standardize notation
sethaxen May 17, 2020
57df366
Remove parens
sethaxen May 17, 2020
8c8790a
Update src/rulesets/LinearAlgebra/structured.jl
sethaxen May 17, 2020
0fdd8a5
Merge branch 'symhermpowseries' of https://github.com/sethaxen/ChainR…
sethaxen May 17, 2020
ef15c71
Fix for Julia 1.0
sethaxen May 17, 2020
ad9d36f
Use correct variable and method name
sethaxen May 17, 2020
0078a58
Accumulate in the triangle in the pullback
sethaxen May 17, 2020
5d1685f
Remove comment
sethaxen May 17, 2020
4366b26
Add eigen and eigvals tests
sethaxen May 17, 2020
f6075a5
Remove outdated comment
sethaxen May 18, 2020
90e3225
Clean up and make constraint functions faster
sethaxen May 18, 2020
e35ffef
Make outputs of int pow of Hermitian be Hermitian
sethaxen May 18, 2020
a604e45
Fix typo in comment
sethaxen May 18, 2020
c28304c
Test most power series functions
sethaxen May 18, 2020
e211034
Don't thunk tangents
sethaxen May 21, 2020
3ccfa91
Merge branch 'master' into symhermpowseries
sethaxen May 25, 2020
e4ec19d
Merge branch 'master' into symhermpowseries
sethaxen Jul 1, 2020
6d7c00c
Make type-stable and use optimal threshold
sethaxen Jul 20, 2020
115c201
Merge branch 'master' into symhermpowseries
sethaxen Jul 20, 2020
c877818
Merge branch 'master' into symhermpowseries
sethaxen Nov 20, 2020
f41bfe0
Split out symmetric/hermitian methods/tests
sethaxen Nov 24, 2020
52eef4d
Use correct pullback of hermitrization
sethaxen Nov 24, 2020
af12bed
Stabilize eigenvector computation
sethaxen Nov 25, 2020
46a4ec4
Test composed pullback
sethaxen Nov 25, 2020
e257b11
Remove all eigendecomposition rules
sethaxen Jan 5, 2021
6010c2f
Merge branch 'master' into symhermpowseries
sethaxen Jan 5, 2021
9571d19
Move to utilities section
sethaxen Jan 5, 2021
d7e3762
Move to utilities section
sethaxen Jan 5, 2021
64f96ee
Separate shared code into its own function
sethaxen Jan 5, 2021
109ce2c
Don't thunk
sethaxen Jan 5, 2021
00cccbc
Use correct function name
sethaxen Jan 5, 2021
eb0a7e2
Correctly broadcast
sethaxen Jan 5, 2021
c5a37da
Remove power rules
sethaxen Jan 6, 2021
648e13e
Merge branch 'master' into symhermpowseries
sethaxen Jan 7, 2021
d8b22f1
Rename to matrix functions
sethaxen Jan 8, 2021
0b0cd85
Remove pow tests
sethaxen Jan 8, 2021
f945b13
Expand test suite
sethaxen Jan 8, 2021
2087dac
Remove sincos rules for now
sethaxen Jan 8, 2021
97ec070
Add references and comments
sethaxen Jan 8, 2021
5f1529d
Add _isindomain
sethaxen Jan 8, 2021
537c1f8
Refactor _matfun
sethaxen Jan 8, 2021
65be168
Add _matfun_frechet
sethaxen Jan 8, 2021
2b3e11a
Broadcast instead of indexing
sethaxen Jan 8, 2021
ac0253c
Add comments and use indexing from paper
sethaxen Jan 8, 2021
b7b83f5
Handle Zeros
sethaxen Jan 8, 2021
1c6a889
Constrain differentials according to primals
sethaxen Jan 8, 2021
21340c9
Support all matrix functions
sethaxen Jan 8, 2021
7bf9b7c
Remove unused methods
sethaxen Jan 8, 2021
7d78762
Support Symmetric{Complex}
sethaxen Jan 8, 2021
d1e9947
Add rules for sincos
sethaxen Jan 8, 2021
444a49b
Make atanh rule type-stable
sethaxen Jan 8, 2021
7dcc8a2
Correctly test type-unstable functions
sethaxen Jan 8, 2021
eb52188
Use correct denominator
sethaxen Jan 8, 2021
f78945a
Add tests for almost-singular and low-rank matrices
sethaxen Jan 8, 2021
c8885cd
Remove out-dated comments
sethaxen Jan 8, 2021
b65f552
Test alternate differentials
sethaxen Jan 8, 2021
e6106f3
Don't use only
sethaxen Jan 8, 2021
68b9597
Remove _hermitrizeback!
sethaxen Jan 8, 2021
0778d7a
Don't use hasproperty, not in old Julia versions
sethaxen Jan 8, 2021
77cba6d
Reduce allocations
sethaxen Jan 8, 2021
fc34770
simplify section name
sethaxen Jan 8, 2021
62afe5e
Simplify line
sethaxen Jan 8, 2021
69bcbe8
Handle mixture of non-Zero and Zero
sethaxen Jan 8, 2021
6f6d38b
Don't loop over unused functions
sethaxen Jan 8, 2021
2f6cbeb
Test against component frules instead of fd
sethaxen Jan 8, 2021
673a258
Test that rules produce same uplo as primal
sethaxen Jan 8, 2021
357ecb8
Apply suggestions from code review
sethaxen Jan 12, 2021
bf2191c
Reuse variable name
sethaxen Jan 12, 2021
f1cba00
Use bang bang convention for maybe-in-place
sethaxen Jan 12, 2021
05e8363
Don't assume the wrapped matrix is mutable
sethaxen Jan 12, 2021
8a60771
Replace hermitrize!
sethaxen Jan 12, 2021
8c77687
Use diagind
sethaxen Jan 12, 2021
3ce3d8a
Remove handling of Zero differential
sethaxen Jan 12, 2021
9b40c09
Unify symbols
sethaxen Jan 12, 2021
e9aef74
Use hasproperty
sethaxen Jan 12, 2021
e338822
Load hasproperty from Compat
sethaxen Jan 12, 2021
17686f2
Replace refs with one to Higham
sethaxen Jan 12, 2021
359a1dc
Add docstrings
sethaxen Jan 12, 2021
32a9cca
Update src/rulesets/LinearAlgebra/symmetric.jl
sethaxen Jan 12, 2021
d81ee72
Merge branch 'master' into symhermpowseries
sethaxen Jan 13, 2021
73d5d01
Merge branch 'master' into symhermpowseries
sethaxen Jan 13, 2021
fd60d44
Increment version number
sethaxen Jan 13, 2021
867ea12
Use utility function
sethaxen Jan 13, 2021
b7f6c40
Stabilize jvp Jacobian dimensions
sethaxen Jan 14, 2021
f956f62
Don't use non-exported function
sethaxen Jan 14, 2021
d3ff01a
Bump required ChainRulesCore
sethaxen Jan 14, 2021
178 changes: 178 additions & 0 deletions src/rulesets/LinearAlgebra/symmetric.jl
@@ -279,6 +279,172 @@ function rrule(::typeof(svdvals), A::LinearAlgebra.RealHermSymComplexHerm{<:BLAS
return S, svdvals_pullback
end

#####
##### matrix functions
#####

# Formula comes from the so-called Daleckiĭ-Kreĭn theorem, originally due to
# Ju. L. Daleckiĭ and S. G. Kreĭn. Integration and differentiation of functions of Hermitian
# operators and applications to the theory of perturbations.
# Amer. Math. Soc. Transl., Series 2, 47:1–30, 1965.
# Stabilization for almost-degenerate matrices due to
# S. D. Axen, 2020. Representing Ensembles of Molecules.
# Appendix D: Automatic differentiation rules for power series functions of diagonalizable matrices
# https://escholarship.org/uc/item/6s62d8pw
# These rules are more stable for degenerate matrices than applying the chain rule to the
# rules for `eigen`.
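
For reference, here is a minimal dense-matrix sketch (not part of the diff; names are illustrative) of the formula these rules implement: for Hermitian A = U*Diagonal(λ)*U' and Y = f(A), the pushforward is ∂Y = U * (P .* (U' * ΔA * U)) * U', where P holds difference quotients of f over the eigenvalues.

using LinearAlgebra

# Naive reference implementation; `_matfun_frechet` below computes the same quantity
# with fewer allocations and a stabilized difference quotient.
function naive_matfun_frechet(f, df, A::Hermitian, ΔA)
    λ, U = eigen(A)
    fλ = f.(λ)
    n = length(λ)
    P = [λ[i] == λ[j] ? df(λ[i]) : (fλ[j] - fλ[i]) / (λ[j] - λ[i]) for i in 1:n, j in 1:n]
    return U * (P .* (U' * ΔA * U)) * U'
end

# e.g. naive_matfun_frechet(exp, exp, Hermitian(randn(3, 3)), Hermitian(randn(3, 3)))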

for func in (:exp, :log, :sqrt, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh)
@eval begin
function frule((_, ΔA), ::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
ΔA isa AbstractZero && return $func(A), ΔA
Y, intermediates = _matfun($func, A)
Ȳ = _matfun_frechet($func, A, Y, ΔA, intermediates)
# If ΔA was hermitian, then ∂Y has the same structure as Y
∂Y = if ishermitian(ΔA) && (isa(Y, Symmetric) || isa(Y, Hermitian))
Member:
What if Y is Diagonal? Do we need to worry about that case?

Member Author:
The matrix functions, when called on a LinearAlgebra.RealHermSymComplexHerm that wraps a StridedMatrix, will always return Union{Symmetric,Hermitian,Matrix}. If someone wraps a Diagonal, all of the matrix functions error at eigen!. The only way one could get a Diagonal would be to implement their own diagonal matrix type, so I think we're safe here.

_symhermlike!(Ȳ, Y)
else
Ȳ
end
return Y, ∂Y
end

function rrule(::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm)
Y, intermediates = _matfun($func, A)
$(Symbol(func, :_pullback))(ΔY::AbstractZero) = (NO_FIELDS, ΔY)
function $(Symbol(func, :_pullback))(ΔY)
# for Hermitian Y, we don't need to realify the diagonal of ΔY, since the
# effect is the same as applying _hermitrize! at the end
∂Y = eltype(Y) <: Real ? real(ΔY) : ΔY
# for matrix functions, the pullback is related to the pushforward by an adjoint
Member:
Because something something linear operator?

Member Author:
No, a different reason. I don't have a name for the property, but it is a property of any function that can be written as a converging power series with real coefficients. I haven't posted the proof anywhere or seen it before, but I'm sure it could be argued from some generic property. For Hermitian matrices, it can be shown by applying the usual inner product trick in the ChainRules docs to derive the pullback.

Member Author:
Perhaps @antoine-levitt knows of this property.

Reviewer:
Sorry, what property are you talking about?

Member Author:
For a matrix function f that can be written in terms of a converging power series with real coefficients, its pullback f^* is related to its pushforward f_* by an adjoint. Specifically, if Y = f(A), ΔA is a tangent of A, and ΔY is a cotangent of Y (adopting ChainRules'/Zygote's conventions for how a cotangent is represented), then (f^*)_Y(ΔY) = (f_*)_{A'}(ΔY). The property means we can write the pullback for any matrix function in terms of its frule.

Member Author:
I got there through a much more complicated route (writing the power series as a recurrence and working out the corresponding pushforwards and pullbacks as recurrences yields this relation), but I'll check to see if we can get there from that property as well. That would be nice, haha.

Reviewer:
Hm, no, it looks more complicated than that; it's df(A)' = df(A') or something like that, and I'm not sure it follows from f(A)' = f(A') (but maybe it does, I haven't checked carefully). I use this trick in the Hermitian case because it means that the differential is self-adjoint on the space of Hermitian matrices equipped with the Frobenius metric, but I didn't know about it in the non-Hermitian case; it's cute.

Member Author:
Seems like it would be worth me writing this up somewhere.

Reviewer:
OK, so let's do the proof for A^n, n ≥ 1. Then we can pass to arbitrary functions by linearity.

df(A)⋅dA = ∑_{k=0}^{n-1} A^k dA A^(n-1-k), so let's look at the adjoint of the linear operator L(A) : dA -> A^k1 dA A^k2; the result follows again by linearity:
⟨L(A) dA, dB⟩ = tr((L(A) dA)' dB) = tr(A^k2' dA' A^k1' dB) = tr(dA' A^k1' dB A^k2'), so adj(L(A)) : dB -> A^k1' dB A^k2', i.e. adj(L(A)) = L(A').

That's probably close to the derivation you had?

Member Author:
Okay yes, but yours is much simpler and more concise than mine. Nice!
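
A quick numerical check of the adjoint relation above (a hedged sketch, not part of the diff), using f(A) = A^3 so the pushforward can be written out explicitly:

using LinearAlgebra

A  = randn(ComplexF64, 4, 4)
dA = randn(ComplexF64, 4, 4)
dB = randn(ComplexF64, 4, 4)

pushforward(A, dA) = A^2 * dA + A * dA * A + dA * A^2   # df(A)[dA] for f(A) = A^3

# ⟨df(A)[dA], dB⟩ = ⟨dA, df(A')[dB]⟩, i.e. the pullback at A is the pushforward at A'
@assert dot(pushforward(A, dA), dB) ≈ dot(dA, pushforward(A', dB))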

Ā = _matfun_frechet($func, A, Y, ∂Y', intermediates)
# the cotangent of Hermitian A should be Hermitian
∂A = typeof(A)(eltype(A) <: Real ? real(Ā) : Ā, A.uplo)
_hermitrize!(∂A.data)
return NO_FIELDS, ∂A
end
return Y, $(Symbol("$(func)_pullback"))
end
end
end
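
A hedged usage sketch of the generated rules (assuming this file is loaded; `Zero` and the tuple-of-tangents calling convention are the ChainRulesCore API of this era):

using LinearAlgebra
using ChainRulesCore

A  = Hermitian(randn(ComplexF64, 3, 3))
ΔA = Hermitian(randn(ComplexF64, 3, 3))

Y, ∂Y = frule((Zero(), ΔA), exp, A)   # Y and ∂Y are both Hermitian
_, exp_pullback = rrule(exp, A)
_, Ā = exp_pullback(∂Y)               # Ā is Hermitian with the same uplo as A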

function frule((_, ΔA), ::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
ΔA isa AbstractZero && return sincos(A), ΔA
sinA, (λ, U, sinλ, cosλ) = _matfun(sin, A)
cosA = _symhermtype(sinA)((U * Diagonal(cosλ)) * U')
tmp = ΔA * U # We will overwrite this matrix several times to hold different values
∂Λ = U' * tmp
∂sinΛ = _muldiffquotmat!(similar(∂Λ), sin, λ, sinλ, cosλ, ∂Λ)
∂cosΛ = _muldiffquotmat!(∂Λ, cos, λ, cosλ, -sinλ, ∂Λ)
∂sinA = _symhermlike!(mul!(∂sinΛ, U, mul!(tmp, ∂sinΛ, U')), sinA)
∂cosA = _symhermlike!(mul!(∂cosΛ, U, mul!(tmp, ∂cosΛ, U')), cosA)
Y = (sinA, cosA)
∂Y = Composite{typeof(Y)}(∂sinA, ∂cosA)
return Y, ∂Y
end

function rrule(::typeof(sincos), A::LinearAlgebra.RealHermSymComplexHerm)
sinA, (λ, U, sinλ, cosλ) = _matfun(sin, A)
cosA = typeof(sinA)((U * Diagonal(cosλ)) * U', sinA.uplo)
Y = (sinA, cosA)
sincos_pullback(ΔY::AbstractZero) = (NO_FIELDS, ΔY)
function sincos_pullback((ΔsinA, ΔcosA)::Composite)
ΔsinA isa AbstractZero && ΔcosA isa AbstractZero && return NO_FIELDS, ΔsinA + ΔcosA
if eltype(A) <: Real
∂sinA, ∂cosA = real(ΔsinA), real(ΔcosA)
else
∂sinA, ∂cosA = ΔsinA, ΔcosA
end
if ∂cosA isa AbstractZero
Ā = _matfun_frechet(sin, A, sinA, ∂sinA, (λ, U, sinλ, cosλ))
elseif ∂sinA isa Zero
Ā = _matfun_frechet(cos, A, cosA, ∂cosA, (λ, U, cosλ, -sinλ))
else
tmp = ∂sinA * U # we will overwrite this with various temporary values during this computation
∂sinΛ = U' * tmp
∂cosΛ = U' * mul!(tmp, ∂cosA, U)
∂Λ = _muldiffquotmat!(∂sinΛ, sin, λ, sinλ, cosλ, ∂sinΛ)
∂Λ = _muldiffquotmat!(∂Λ, cos, λ, cosλ, -sinλ, ∂cosΛ, true)
Ā = mul!(∂Λ, U, mul!(tmp, ∂Λ, U'))
end
_hermitrize!(Ā)
∂A = typeof(A)(Ā, A.uplo)
return NO_FIELDS, ∂A
end
return Y, sincos_pullback
end

# compute the matrix function f(A), returning also a cache of intermediates for computing
# the pushforward or pullback.
# Note any function `f` used with this **must** have a `frule` defined on it.
function _matfun(f, A::LinearAlgebra.RealHermSymComplexHerm)
λ, U = eigen(A)
if all(λi -> _isindomain(f, λi), λ)
fλ_df_dλ = map(λi -> frule((Zero(), One()), f, λi), λ)
else # promote to complex if necessary
fλ_df_dλ = map(λi -> frule((Zero(), One()), f, complex(λi)), λ)
end
fλ = first.(fλ_df_dλ)
df_dλ = last.(unthunk.(fλ_df_dλ))
fA = (U * Diagonal(fλ)) * U'
Y = if eltype(A) <: Real
Symmetric(fA)
elseif eltype(fλ) <: Complex
fA
else
Hermitian(fA)
end
intermediates = (λ, U, fλ, df_dλ)
return Y, intermediates
end

# Fréchet derivative of matrix function f
# Computes ∂Y = U * (P .* (U' * ΔA * U)) * U' with fewer allocations
function _matfun_frechet(f, A::LinearAlgebra.RealHermSymComplexHerm, Y, ΔA, (λ, U, fλ, df_dλ))
tmp = ΔA * U
∂Λ = U' * tmp
∂fΛ = _muldiffquotmat!(∂Λ, f, λ, fλ, df_dλ, ∂Λ)
# reuse intermediate if possible
if eltype(tmp) <: Real && eltype(∂fΛ) <: Complex
tmp2 = ∂fΛ * U'
else
tmp2 = mul!(tmp, ∂fΛ, U')
end
∂Y = mul!(∂fΛ, U, tmp2)
return ∂Y
end
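
A hedged usage sketch (assuming this file is loaded): `_matfun` and `_matfun_frechet` reproduce the dense formula U * (P .* (U' * ΔA * U)) * U' for f = exp on a random Hermitian matrix:

using LinearAlgebra

A  = Hermitian(randn(ComplexF64, 4, 4))
ΔA = Hermitian(randn(ComplexF64, 4, 4))
Y, cache = _matfun(exp, A)
λ, U, fλ, df_dλ = cache
P = _diffquot.(exp, λ, λ', fλ, transpose(fλ), df_dλ, transpose(df_dλ))
@assert _matfun_frechet(exp, A, Y, ΔA, cache) ≈ U * (P .* (U' * ΔA * U)) * U'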

# difference quotient, i.e. Pᵢⱼ = (f(λⱼ) - f(λᵢ)) / (λⱼ - λᵢ), with f'(λᵢ) when λᵢ=λⱼ
function _diffquot(f, λi, λj, fλi, fλj, ∂fλi, ∂fλj)
T = Base.promote_typeof(λi, λj, fλi, fλj, ∂fλi, ∂fλj)
Δλ = λj - λi
iszero(Δλ) && return T(∂fλi)
# handle round-off error using Maclaurin series of (f(λᵢ + Δλ) - f(λᵢ)) / Δλ wrt Δλ
# and approximating f''(λᵢ) with forward difference (f'(λᵢ + Δλ) - f'(λᵢ)) / Δλ

Reviewer:
Oh, that's more clever than what I do and gets you eps^(2/3) (I only get eps^(1/2)), nice! Of course this is all assuming that all quantities are order 1 (or else you need to be careful about relative errors rather than absolute ones), but it's fine.

Member Author:
That's right, yeah. TBH, error analysis is not my thing. I plan to come back and try to improve this in the future. For the moment, this seems to be fine. I haven't been able to construct a random almost-degenerate matrix for which the pushforward/pullback disagrees with finite differences so far.

Reviewer:
Well, it'll agree to eps^(2/3) if what you wrote is correct. I can check the math if you want.

Member Author:
That'd be great if you have the time!

Reviewer:
Yup, checks out. The approximation (f(λ₁) - f(λ₂))/(λ₁ - λ₂) ≈ (f'(λ₁) + f'(λ₂))/2 is accurate to order Δλ². The roundoff error when computing (f(λ₁) - f(λ₂))/(λ₁ - λ₂) is of order eps/Δλ. So it's advantageous to switch to the approximation when eps/Δλ ≈ Δλ², i.e. Δλ = cbrt(eps). The worst-case accuracy is indeed eps^(2/3).

Member Author:
Awesome! Thanks for checking that!

# so (f(λᵢ + Δλ) - f(λᵢ)) / Δλ = (f'(λᵢ + Δλ) + f'(λᵢ)) / 2 + O(Δλ^2)
# total error on the order of f(λᵢ) * eps()^(2/3)
abs(Δλ) < cbrt(eps(real(T))) && return T((∂fλj + ∂fλi) / 2)
Δfλ = fλj - fλi
return T(Δfλ / Δλ)
end
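
Following up on the error-analysis discussion above, a small sketch (not part of the diff) of why the cbrt(eps()) switch helps: the naive quotient loses accuracy like eps()/Δλ, the derivative average is accurate to O(Δλ²), and the two error curves cross near Δλ ≈ cbrt(eps()) ≈ 6e-6.

# High-accuracy reference for f = exp: (e^λⱼ - e^λᵢ)/Δλ = e^((λᵢ+λⱼ)/2) * sinh(Δλ/2)/(Δλ/2)
λi = 1.0
for Δλ in (1e-3, 1e-6, 1e-9)
    λj = λi + Δλ
    exact   = exp((λi + λj) / 2) * sinh(Δλ / 2) / (Δλ / 2)
    naive   = (exp(λj) - exp(λi)) / Δλ    # roundoff error ~ eps()/Δλ
    average = (exp(λj) + exp(λi)) / 2     # exp' = exp; truncation error ~ Δλ^2
    println((Δλ, abs(naive - exact), abs(average - exact)))
end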

# broadcast multiply Δ by the matrix of difference quotients P, storing the result in PΔ.
# If β is nonzero, then @. PΔ = β*PΔ + P*Δ
# if type of PΔ is incompatible with result, new matrix is allocated
function _muldiffquotmat!(PΔ, f, λ, fλ, ∂fλ, Δ, β = false)
if eltype(PΔ) <: Real && eltype(fλ) <: Complex
return β .* PΔ .+ _diffquot.(f, λ, λ', fλ, transpose(fλ), ∂fλ, transpose(∂fλ)) .* Δ
else
PΔ .= β .* PΔ .+ _diffquot.(f, λ, λ', fλ, transpose(fλ), ∂fλ, transpose(∂fλ)) .* Δ
return PΔ
end
end

_isindomain(f, x) = true
_isindomain(::Union{typeof(acos),typeof(asin)}, x::Real) = -1 ≤ x ≤ 1
_isindomain(::typeof(acosh), x::Real) = x ≥ 1
_isindomain(::Union{typeof(log),typeof(sqrt)}, x::Real) = x ≥ 0
Comment on lines +443 to +446

Member:
I really want a package that just knows about this kind of thing: SpecialFunctionProperties.jl
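
As a hedged illustration of how these checks are used by `_matfun` (assuming this file and the scalar rules it relies on are loaded): an eigenvalue outside the real domain of `f` promotes the computation to complex arithmetic.

using LinearAlgebra

A = Symmetric([1.0 2.0; 2.0 1.0])    # eigenvalues -1 and 3; -1 is outside log's real domain
Y, _ = _matfun(log, A)
@assert Y isa Symmetric{<:Complex}   # eigenvalues were promoted via complex(λi)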


#####
##### utilities
#####
@@ -288,6 +454,18 @@ _symhermtype(::Type{<:Symmetric}) = Symmetric
_symhermtype(::Type{<:Hermitian}) = Hermitian
_symhermtype(A) = _symhermtype(typeof(A))

function _realifydiag!(A)
for i in axes(A, 1)
@inbounds A[i, i] = real(A[i, i])
end
return A
end

function _symhermlike!(A, S::Union{Symmetric,Hermitian})
A isa Hermitian{<:Complex} && _realifydiag!(A)
return typeof(S)(A, S.uplo)
end

# in-place hermitrize matrix
function _hermitrize!(A)
n = size(A, 1)