-
Notifications
You must be signed in to change notification settings - Fork 89
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Symmetric/Hermitian matrix function rules #193
Conversation
These rules will probably not be usable with Zygote without some modification even after FluxML/Zygote.jl#366 is merged, due to Zygote not enforcing type constraints. e.g. julia> using Zygote, ChainRules
julia> A = Symmetric(randn(3, 3));
julia> Zygote.pullback(sum ∘ exp, A)[2](1.0)[1]
3×3 Array{Float64,2}:
16.1831 4.00247 2.91251
4.00247 -0.0429689 -0.331864
2.91251 -0.331864 -0.546545
julia> rrule(exp, A)[2](ones(size(A)))[2] |> unthunk # this is identical, and hey, we have type constraints!
3×3 Symmetric{Float64,Array{Float64,2}}:
16.1831 4.00247 2.91251
4.00247 -0.0429689 -0.331864
2.91251 -0.331864 -0.546545
julia> Zygote.pullback(sum ∘ exp, A)[2](im)[1]
3×3 Array{Complex{Float64},2}:
0.0+16.1831im 0.0+4.00247im 0.0+2.91251im
0.0+4.00247im 0.0-0.0429689im 0.0-0.331864im
0.0+2.91251im 0.0-0.331864im 0.0-0.546545im
julia> rrule(exp, A)[2](im*ones(size(A)))[2] |> unthunk # that's...not even close
3×3 Symmetric{Complex{Float64},Array{Complex{Float64},2}}:
0.0-6.6388im 0.0+5.58432im 0.0+5.86321im
0.0+5.58432im 0.0-0.608647im 0.0-0.610593im
0.0+5.86321im 0.0-0.610593im 0.0-1.26019im I'm not certain if there's an easy fix to make the two compatible. |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Cool! And thanks!
I am not well placed to review the maths here. On a quick look thing seem sensible.
A few minor comments :) Hopefully we can find someone who can who's able review more thoroughly :)
Co-authored-by: Nick Robinson <[email protected]>
…ules.jl into symhermpowseries
Δλ = λj - λi | ||
iszero(Δλ) && return T(∂fλi) | ||
# handle round-off error using Maclaurin series of (f(λᵢ + Δλ) - f(λᵢ)) / Δλ wrt Δλ | ||
# and approximating f''(λᵢ) with forward difference (f'(λᵢ + Δλ) - f'(λᵢ)) / Δλ |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
oh that's more clever than what I do and gets you eps^2/3 (I only get eps^1/2), nice! Of course this is all assuming that all quantities are order 1 (or else you need to be careful about relative errors rather than absolute ones), but it's fine.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That's right, yeah. TBH, error analysis is not my thing. I plan to come back and try to improve this in the future. For the moment, this seems to be fine. I haven't been able to construct a random almost-degenerate matrix for which the pushforward/pullback disagrees with finite differences so far.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
well it'll agree to eps^2/3 if what you wrote is correct. I can check the math if you want
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
That'd be great if you have the time!
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup, checks out. The approximation (f(l1) - f(l2))/(l1-l2) ~= (f'(l1) + f'(l2))/2 is accurate to order dl^2. The roundoff error when computing (f(l1) - f(l2))/(l1-l2) is order eps/dl. So it's advantageous to switch to the approximation when eps/dl ~= dl^2 => dl = cbrt(eps). The worst case accuracy is indeed eps^2/3.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Awesome! Thanks for checking that!
Regarding reference, I'd just cite Higham, pretty sure he discusses DK theorem somewhere. |
I made a few changes to ensure that the rules do the right thing when the wrapped array is immutable. This isn't tested though. |
Blocked by JuliaDiff/ChainRulesCore.jl#279 and need to stabilize output type for |
false | ||
end | ||
if istypestable | ||
if ChainRulesTestUtils._is_inferrable(f, A) |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't love that we are using an explictly nonexported function.
Do we even need to be testing @inferred
here, or can we rely on frule_test
to do that for us later
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Unfortunately we cannot use frule_test
or rrule_test
here because we need to work around the outputs of some functions being type-unstable. (and rrule
's won't agree with FiniteDifferences for Symmetric
/Hermitian
, only with them composed with the constructor). I copied over a simplified is_inferrable
.
882: Remove Symmetric/Hermitian matrix function rules r=DhairyaLGandhi a=sethaxen JuliaDiff/ChainRules.jl#193 (which is about to be merged) adds to ChainRules the `rrule`s for matrix functions (i.e. power series functions) of `Symmetric` and `Hermitian` matrices. This PR removes the corresponding adjoints from Zygote. Tests pass locally but won't here until a new CR release is registered. This PR is blocked until then. Co-authored-by: Seth Axen <[email protected]> Co-authored-by: Dhairya Gandhi <[email protected]>
This PR adapts Zygote's adjoint rules for
power series functions (analytic functions)matrix functions of realSymmetric
andHermitian
inputs, originally introduced in FluxML/Zygote.jl#355 (derivation here)It additionally adds
frule
s for the same functions.Hermitian
below also includesSymmetric{<:Real}
.A few notes (
Hermitian
below also includesSymmetric{<:Real}
):It also adds rules forMoved to improve performance of complex numbers JuliaLang/julia#323eigen
andeigvals
, since they share a number of utility functions with the power series rules. Theeigen
rule should probably be a fallback for add 128-bit integer types JuliaLang/julia#179.Hermitian
, etc, notComposite{<:Hermitian}
(relates treat newline as space in argument lists JuliaLang/julia#191).Hermitian
matrices having real eigenvalues means that therrule
s are not dependent on the conjugation convention of the cotangent.To-do:
Because only a few utility functions differ between the
Hermitian{<:Complex}
case and theSymmetric{<:Real}
andHermitian{<:Real}
cases, I don't think not having FD v0.10.0 support to test the rules with complex inputs should hold up this PR.