Skip to content

Commit

Permalink
Merge #882
Browse files Browse the repository at this point in the history
882: Remove Symmetric/Hermitian matrix function rules r=DhairyaLGandhi a=sethaxen

JuliaDiff/ChainRules.jl#193 (which is about to be merged) adds to ChainRules the `rrule`s for matrix functions (i.e. power series functions) of `Symmetric` and `Hermitian` matrices. This PR removes the corresponding adjoints from Zygote. Tests pass locally but won't pass here until a new ChainRules release is registered. This PR is blocked until then.

Co-authored-by: Seth Axen <[email protected]>
Co-authored-by: Dhairya Gandhi <[email protected]>
  • Loading branch information
3 people authored Jan 15, 2021
2 parents 52bcea8 + 044adb3 commit d6e065c
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 54 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444"

[compat]
AbstractFFTs = "0.5, 1.0"
ChainRules = "0.7.45"
ChainRules = "0.7.47"
DiffRules = "1.0"
FillArrays = "0.8, 0.9, 0.10, 0.11"
ForwardDiff = "0.10"
Expand Down
61 changes: 9 additions & 52 deletions src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -570,6 +570,14 @@ end
return (Ā,)
end

# Zygote's own adjoint for exp(::AbstractArray) would otherwise take precedence over the
# ChainRules rrule for exp of a Hermitian/Symmetric matrix, so that rrule is invoked
# explicitly here. Delete this method once the generic exp rule lives in ChainRules.
@adjoint function exp(A::LinearAlgebra.RealHermSymComplexHerm)
    expA, rule_back = chain_rrule(exp, A)
    # `rule_back` returns one entry per rrule argument; entry 2 is the one kept for `A`.
    exp_pullback(Δ) = (rule_back(Δ)[2],)
    return expA, exp_pullback
end

# Hermitian/Symmetric matrix functions that can be written as power series
# Real-element arrays are handed back as-is by this method.
function _realifydiag!(A::AbstractArray{<:Real})
    return A
end
function _realifydiag!(A)
Expand All @@ -581,10 +589,7 @@ function _realifydiag!(A)
end
# Adjoint of `_realifydiag!`: the forward value is `_realifydiag!(A)` and the pullback
# applies the same `_realifydiag!` to the incoming cotangent.
@adjoint function _realifydiag!(A)
    return _realifydiag!(A), Δ -> (_realifydiag!(Δ),)
end

# `_hasrealdomain(f, x)` returns `true` when every element of `x` satisfies the range
# check attached to `f` below, and `false` otherwise. Functions without a dedicated
# method are treated as unrestricted.
#   acos, asin   : -1 <= x <= 1
#   acosh        : x >= 1
#   log, sqrt, ^ : x >= 0
_hasrealdomain(f, x) = true
_hasrealdomain(::Union{typeof.((acos,asin))...}, x) = all(x -> -1 <= x <= 1, x)
_hasrealdomain(::typeof(acosh), x) = all(x -> x >= 1, x)
_hasrealdomain(::Union{typeof.((log,sqrt,^))...}, x) = all(x -> x >= 0, x)
_hasrealdomain(::typeof(^), x) = all(x -> x >= 0, x)

# Return the eigenvalues `λ` unchanged when `_hasrealdomain(f, λ)` holds; otherwise
# return them converted elementwise with `complex`.
function _process_series_eigvals(f, λ)
    if _hasrealdomain(f, λ)
        return λ
    else
        return complex.(λ)
    end
end

Expand All @@ -598,16 +603,6 @@ _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Complex}, ::AbstractVector

# Compute function on eigvals, thunks for conjugates of 1st and 2nd derivatives,
# and function to pull back adjoints to args
function _pullback_series_func_scalar(f, λ, args...)
    # Eigenvalues, possibly promoted to complex by `_process_series_eigvals`.
    λin = _process_series_eigvals(f, λ)
    # Differentiate the broadcast of `f` over the eigenvalues and any extra arguments.
    broadcast_f(x, extra...) = f.(x, extra...)
    fvals, pull = Zygote.pullback(broadcast_f, λin, args...)
    nλ = length(λ)
    # Thunk for the first-derivative term: pull back a vector of ones, keep the
    # eigenvalue slot.
    first_deriv() = pull(ones(nλ))[1]
    second_deriv() = nothing # TODO: add 2nd deriv
    # Pullback for the extra arguments only; yields an empty tuple when there are none.
    args_back = if isempty(args)
        _ -> ()
    else
        cotangent -> tail(pull(cotangent))
    end
    return (fvals, first_deriv, second_deriv, args_back)
end

function _pullback_series_func_scalar(f::typeof(^), λ, p)
compλ = _process_series_eigvals(f, λ)
r, powλ = isinteger(p) ? (Integer(p), λ) : (p, compλ)
Expand All @@ -618,11 +613,6 @@ function _pullback_series_func_scalar(f::typeof(^), λ, p)
f̄λ -> (dot(fλ .* log.(compλ), f̄λ),))
end

# Specialisation for `exp`: the elementwise exponential is computed once and the same
# array is served by both derivative thunks. There are no extra arguments, so the
# argument pullback always returns an empty tuple.
function _pullback_series_func_scalar(f::typeof(exp), λ)
    vals = exp.(λ)
    first_deriv() = vals
    second_deriv() = vals
    no_args_back(_) = ()
    return (vals, first_deriv, second_deriv, no_args_back)
end

# Thin wrapper that calls `f` on `A` and any extra arguments.
function _apply_series_func(f, A, args...)
    return f(A, args...)
end

@adjoint function _apply_series_func(f, A, args...)
Expand Down Expand Up @@ -673,39 +663,6 @@ function _pullback(cx::AContext,
return _pullback(cx, (A, p) -> _apply_series_func(f, A, p), A, p)
end

# For each listed function, define a `_pullback` method on real-symmetric /
# complex-Hermitian matrices that forwards to `_apply_series_func`.
for func in (
    :exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh,
    :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt,
)
    @eval function _pullback(
        cx::AContext, f::typeof($func), A::LinearAlgebra.RealHermSymComplexHerm
    )
        return _pullback(cx, X -> _apply_series_func(f, X), A)
    end
end

# Adjoint of `sincos` for a real-symmetric or complex-Hermitian matrix `A`.
#
# Forward pass: eigendecompose `A`, apply the scalar `sincos` to each eigenvalue, and
# rebuild `sin(A)` and `cos(A)` as `U * Diagonal(...) * U'`. Both results are then passed
# through `_process_series_matrix` (defined elsewhere; presumably it fixes up the wrapper
# type of the result -- confirm against its definition).
#
# Pullback: takes the cotangent `Ω̄` of the `(sin(A), cos(A))` pair and returns a
# one-element tuple holding the cotangent of `A`.
@adjoint function sincos(A::LinearAlgebra.RealHermSymComplexHerm)
    n = LinearAlgebra.checksquare(A)
    λ, U = eigen(A)
    # Scalar sines/cosines are written element by element into `Buffer`s and then
    # materialised with `copy`.
    sλ, cλ = Buffer(λ), Buffer(λ)
    for i in Base.OneTo(n)
        @inbounds sλ[i], cλ[i] = sincos(λ[i])
    end
    sinλ, cosλ = copy(sλ), copy(cλ)
    sinA, cosA = U * Diagonal(sinλ) * U', U * Diagonal(cosλ) * U'
    Ω, processback = Zygote.pullback(sinA, cosA) do s,c
        return (_process_series_matrix(sin, s, A, λ),
                _process_series_matrix(cos, c, A, λ))
    end
    return Ω, function (Ω̄)
        # Undo the post-processing, then rotate both cotangents into the eigenbasis.
        s̄inA, c̄osA = processback(Ω̄)
        s̄inΛ, c̄osΛ = U' * s̄inA * U, U' * c̄osA * U
        # NOTE(review): `_pairdiffquotmat` is defined elsewhere; it is called with the
        # function, its values and its first and second derivatives at the eigenvalues,
        # and presumably returns the matrix of pairwise difference quotients -- confirm.
        PS = _pairdiffquotmat(sin, n, λ, sinλ, cosλ, -sinλ)
        PC = _pairdiffquotmat(cos, n, λ, cosλ, -sinλ, -cosλ)
        # Combine both contributions and rotate back out of the eigenbasis.
        Ā = U * (PS .* s̄inΛ .+ PC .* c̄osΛ) * U'
        return (Ā,)
    end
end

# ChainRules has this also but does not use FillArrays, so we have our own definition
# for improved performance. See https://github.com/JuliaDiff/ChainRules.jl/issues/46
Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix)
Expand Down
9 changes: 8 additions & 1 deletion test/gradcheck.jl
Original file line number Diff line number Diff line change
Expand Up @@ -865,10 +865,17 @@ _randmatseries(rng, ::typeof(atanh), T, n, domain::Type{Complex}) = nothing

@test _gradtest_hermsym(f, ST, A)

y = Zygote.pullback(f, A)[1]
# The primal value from the pullback must match a plain call in both value and type.
y, back = Zygote.pullback(f, A)
y2 = f(A)
@test y ≈ y2
@test typeof(y) == typeof(y2)
# Build a random cotangent shaped like `y`; when `y` is Symmetric/Hermitian, wrap the
# cotangent in the same type with the same `uplo`.
ȳ = randn(eltype(y), size(y))
if y isa Union{Symmetric,Hermitian}
    ȳ = typeof(y)(ȳ, y.uplo)
end
# The cotangent for `A` must have the same type and `uplo` as `A` itself.
Ā = back(ȳ)[1]
@test typeof(Ā) == typeof(A)
@test Ā.uplo == A.uplo

@testset "similar eigenvalues" begin
λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10
Expand Down

0 comments on commit d6e065c

Please sign in to comment.