Symmetric/Hermitian matrix function rules #193

Merged
merged 95 commits into JuliaDiff:master from symhermpowseries
Jan 14, 2021

Conversation

sethaxen
Member

@sethaxen sethaxen commented May 17, 2020

This PR adapts Zygote's adjoint rules for matrix functions (power series functions, i.e. analytic functions) of real Symmetric and Hermitian inputs, originally introduced in FluxML/Zygote.jl#355 (derivation here).
It additionally adds frules for the same functions.

A few notes (Hermitian below also includes Symmetric{<:Real}):

To-do:

Because only a few utility functions differ between the Hermitian{<:Complex} case and the Symmetric{<:Real} and Hermitian{<:Real} cases, I don't think the lack of FD v0.10.0 support for testing the rules with complex inputs should hold up this PR.

@sethaxen
Member Author

These rules will probably not be usable with Zygote without some modification even after FluxML/Zygote.jl#366 is merged, due to Zygote not enforcing type constraints. e.g.

julia> using Zygote, ChainRules

julia> A = Symmetric(randn(3, 3));

julia> Zygote.pullback(sum ∘ exp, A)[2](1.0)[1]
3×3 Array{Float64,2}:
 16.1831    4.00247     2.91251
  4.00247  -0.0429689  -0.331864
  2.91251  -0.331864   -0.546545

julia> rrule(exp, A)[2](ones(size(A)))[2] |> unthunk # this is identical, and hey, we have type constraints!
3×3 Symmetric{Float64,Array{Float64,2}}:
 16.1831    4.00247     2.91251
  4.00247  -0.0429689  -0.331864
  2.91251  -0.331864   -0.546545

julia> Zygote.pullback(sum ∘ exp, A)[2](im)[1]
3×3 Array{Complex{Float64},2}:
 0.0+16.1831im  0.0+4.00247im    0.0+2.91251im
 0.0+4.00247im  0.0-0.0429689im  0.0-0.331864im
 0.0+2.91251im  0.0-0.331864im   0.0-0.546545im

julia> rrule(exp, A)[2](im*ones(size(A)))[2] |> unthunk # that's...not even close
3×3 Symmetric{Complex{Float64},Array{Complex{Float64},2}}:
 0.0-6.6388im   0.0+5.58432im   0.0+5.86321im
 0.0+5.58432im  0.0-0.608647im  0.0-0.610593im
 0.0+5.86321im  0.0-0.610593im  0.0-1.26019im

I'm not certain if there's an easy fix to make the two compatible.

Contributor

@nickrobinson251 nickrobinson251 left a comment

Cool! And thanks!

I am not well placed to review the maths here. On a quick look, things seem sensible.
A few minor comments :) Hopefully we can find someone who's able to review more thoroughly :)

src/rulesets/LinearAlgebra/structured.jl
Δλ = λj - λi
iszero(Δλ) && return T(∂fλi)
# handle round-off error using Maclaurin series of (f(λᵢ + Δλ) - f(λᵢ)) / Δλ wrt Δλ
# and approximating f''(λᵢ) with forward difference (f'(λᵢ + Δλ) - f'(λᵢ)) / Δλ
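
For reference, the expansion that these two comment lines describe can be written out as below. This is a sketch assuming f is smooth near λᵢ and writing λⱼ = λᵢ + Δλ; the remainder orders are the standard Taylor-series ones and are not taken from the PR itself.

```latex
\begin{align*}
\frac{f(\lambda_i + \Delta\lambda) - f(\lambda_i)}{\Delta\lambda}
  &= f'(\lambda_i) + \frac{\Delta\lambda}{2}\, f''(\lambda_i) + O(\Delta\lambda^2), \\
f''(\lambda_i)
  &= \frac{f'(\lambda_i + \Delta\lambda) - f'(\lambda_i)}{\Delta\lambda} + O(\Delta\lambda), \\
\Rightarrow\quad
\frac{f(\lambda_j) - f(\lambda_i)}{\lambda_j - \lambda_i}
  &= \frac{f'(\lambda_i) + f'(\lambda_j)}{2} + O(\Delta\lambda^2).
\end{align*}
```

Substituting the forward difference for f''(λᵢ) turns the second term into half the change in f', which is why the result is the average of the two derivatives.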

oh that's more clever than what I do and gets you eps^(2/3) (I only get eps^(1/2)), nice! Of course this is all assuming that all quantities are order 1 (or else you need to be careful about relative errors rather than absolute ones), but it's fine.

Member Author

That's right, yeah. TBH, error analysis is not my thing. I plan to come back and try to improve this in the future. For the moment, this seems to be fine. I haven't been able to construct a random almost-degenerate matrix for which the pushforward/pullback disagrees with finite differences so far.

well it'll agree to eps^(2/3) if what you wrote is correct. I can check the math if you want

Member Author

That'd be great if you have the time!

Yup, checks out. With dl = l1 - l2, the approximation (f(l1) - f(l2))/(l1-l2) ~= (f'(l1) + f'(l2))/2 is accurate to order dl^2. The roundoff error when computing (f(l1) - f(l2))/(l1-l2) is order eps/dl. So it's advantageous to switch to the approximation when eps/dl ~= dl^2 => dl = cbrt(eps). The worst-case accuracy is indeed eps^(2/3).
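
The switch-over rule worked out in this comment can be sketched in Julia roughly as follows. The function name `divided_difference`, its signature, and the exact threshold `cbrt(eps)` are assumptions for illustration; the PR's actual code in structured.jl may be organised differently.

```julia
# Hypothetical helper (not the PR's code): divided difference of f between
# two real eigenvalues λi and λj, where df is the derivative of f.
function divided_difference(f, df, λi::Real, λj::Real)
    Δλ = λj - λi
    # Exactly degenerate eigenvalues: the divided difference is the derivative.
    iszero(Δλ) && return float(df(λi))
    # Assumed threshold: below cbrt(eps) the round-off error of the quotient
    # (order eps/|Δλ|) exceeds the truncation error of the average (order Δλ^2).
    if abs(Δλ) < cbrt(eps(float(typeof(Δλ))))
        return (df(λi) + df(λj)) / 2
    end
    # Well-separated eigenvalues: the plain quotient is accurate enough.
    return (f(λj) - f(λi)) / Δλ
end
```

For example, `divided_difference(exp, exp, 1.0, 1.0 + 1e-9)` takes the averaged-derivative branch, while `divided_difference(exp, exp, 0.0, 1.0)` uses the plain quotient.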

Member Author

Awesome! Thanks for checking that!

@antoine-levitt

Regarding a reference, I'd just cite Higham; I'm pretty sure he discusses the DK (Daleckii-Krein) theorem somewhere.

@sethaxen
Member Author

I made a few changes to ensure that the rules do the right thing when the wrapped array is immutable. This isn't tested though.

@sethaxen
Member Author

Blocked by JuliaDiff/ChainRulesCore.jl#279; I also need to stabilize the output type for the jvp.

false
end
if istypestable
if ChainRulesTestUtils._is_inferrable(f, A)
Member

I don't love that we are using an explicitly nonexported function.
Do we even need to be testing @inferred here, or can we rely on frule_test to do that for us later?

Member Author

@sethaxen sethaxen Jan 14, 2021

Unfortunately we cannot use frule_test or rrule_test here because we need to work around the outputs of some functions being type-unstable (and the rrules won't agree with FiniteDifferences for Symmetric/Hermitian, only with them composed with the constructor). I copied over a simplified is_inferrable.
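
A simplified inferrability check of the kind mentioned here might look like the sketch below. It relies only on `Test.@inferred` from the standard library; the name `is_inferrable` follows the comment, but the body is an assumption and not the code that was copied into the PR's tests.

```julia
using Test

# Hypothetical sketch: returns true if the return type of f(args...) is
# inferred concretely, false if @inferred reports a mismatch.
function is_inferrable(f, args...)
    try
        @inferred f(args...)
        return true
    catch err
        # @inferred signals a mismatch with an ErrorException;
        # anything else is a genuine failure of f and is rethrown.
        err isa ErrorException || rethrow()
        return false
    end
end
```

A test can then branch on the result, for instance applying `@inferred` to a rule only when the primal function itself is type-stable for the given input.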

@sethaxen sethaxen merged commit c59e6f8 into JuliaDiff:master Jan 14, 2021
@sethaxen sethaxen deleted the symhermpowseries branch January 14, 2021 23:10
bors bot added a commit to FluxML/Zygote.jl that referenced this pull request Jan 15, 2021
882: Remove Symmetric/Hermitian matrix function rules r=DhairyaLGandhi a=sethaxen

JuliaDiff/ChainRules.jl#193 (which is about to be merged) adds to ChainRules the `rrule`s for matrix functions (i.e. power series functions) of `Symmetric` and `Hermitian` matrices. This PR removes the corresponding adjoints from Zygote. Tests pass locally but won't here until a new CR release is registered. This PR is blocked until then.

Co-authored-by: Seth Axen
Co-authored-by: Dhairya Gandhi