diff --git a/Project.toml b/Project.toml index 64b7962a1..0c75bab01 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "0.7.45" +ChainRules = "0.7.47" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11" ForwardDiff = "0.10" diff --git a/src/lib/array.jl b/src/lib/array.jl index 733026c84..102dc09c9 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -570,6 +570,14 @@ end return (Ā,) end +# The adjoint for exp(::AbstractArray) intercepts ChainRules' rrule for exp(::Hermitian), +# so we call it manually. This can be removed when the generic rule for exp is moved to +# ChainRules +@adjoint function exp(A::LinearAlgebra.RealHermSymComplexHerm) + Y, back = chain_rrule(exp, A) + return Y, Δ -> (back(Δ)[2],) +end + # Hermitian/Symmetric matrix functions that can be written as power series _realifydiag!(A::AbstractArray{<:Real}) = A function _realifydiag!(A) @@ -581,10 +589,7 @@ function _realifydiag!(A) end @adjoint _realifydiag!(A) = _realifydiag!(A), Δ -> (_realifydiag!(Δ),) -_hasrealdomain(f, x) = true -_hasrealdomain(::Union{typeof.((acos,asin))...}, x) = all(x -> -1 ≤ x ≤ 1, x) -_hasrealdomain(::typeof(acosh), x) = all(x -> x ≥ 1, x) -_hasrealdomain(::Union{typeof.((log,sqrt,^))...}, x) = all(x -> x ≥ 0, x) +_hasrealdomain(::typeof(^), x) = all(x -> x ≥ 0, x) _process_series_eigvals(f, λ) = _hasrealdomain(f, λ) ? λ : complex.(λ) @@ -598,16 +603,6 @@ _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Complex}, ::AbstractVector # Compute function on eigvals, thunks for conjugates of 1st and 2nd derivatives, # and function to pull back adjoints to args -function _pullback_series_func_scalar(f, λ, args...) - compλ = _process_series_eigvals(f, λ) - fλ, fback = Zygote.pullback((x,args...) -> f.(x, args...), compλ, args...) - n = length(λ) - return (fλ, - ()->fback(ones(n))[1], - ()->nothing, # TODO: add 2nd deriv - isempty(args) ? _ -> () : f̄λ -> tail(fback(f̄λ))) -end - function _pullback_series_func_scalar(f::typeof(^), λ, p) compλ = _process_series_eigvals(f, λ) r, powλ = isinteger(p) ? (Integer(p), λ) : (p, compλ) @@ -618,11 +613,6 @@ function _pullback_series_func_scalar(f::typeof(^), λ, p) f̄λ -> (dot(fλ .* log.(compλ), f̄λ),)) end -function _pullback_series_func_scalar(f::typeof(exp), λ) - expλ = exp.(λ) - return expλ, ()->expλ, ()->expλ, _ -> () -end - _apply_series_func(f, A, args...) = f(A, args...) @adjoint function _apply_series_func(f, A, args...) @@ -673,39 +663,6 @@ function _pullback(cx::AContext, return _pullback(cx, (A, p) -> _apply_series_func(f, A, p), A, p) end -for func in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt) - @eval begin - function _pullback(cx::AContext, - f::typeof($func), - A::LinearAlgebra.RealHermSymComplexHerm) - return _pullback(cx, A -> _apply_series_func(f, A), A) - end - end -end - -@adjoint function sincos(A::LinearAlgebra.RealHermSymComplexHerm) - n = LinearAlgebra.checksquare(A) - λ, U = eigen(A) - sλ, cλ = Buffer(λ), Buffer(λ) - for i in Base.OneTo(n) - @inbounds sλ[i], cλ[i] = sincos(λ[i]) - end - sinλ, cosλ = copy(sλ), copy(cλ) - sinA, cosA = U * Diagonal(sinλ) * U', U * Diagonal(cosλ) * U' - Ω, processback = Zygote.pullback(sinA, cosA) do s,c - return (_process_series_matrix(sin, s, A, λ), - _process_series_matrix(cos, c, A, λ)) - end - return Ω, function (Ω̄) - s̄inA, c̄osA = processback(Ω̄) - s̄inΛ, c̄osΛ = U' * s̄inA * U, U' * c̄osA * U - PS = _pairdiffquotmat(sin, n, λ, sinλ, cosλ, -sinλ) - PC = _pairdiffquotmat(cos, n, λ, cosλ, -sinλ, -cosλ) - Ā = U * (PS .* s̄inΛ .+ PC .* c̄osΛ) * U' - return (Ā,) - end -end - # ChainRules has this also but does not use FillArrays, so we have our own definition # for improved performance. See https://github.com/JuliaDiff/ChainRules.jl/issues/46 Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index d36ea68fc..a23a52aeb 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -865,10 +865,17 @@ _randmatseries(rng, ::typeof(atanh), T, n, domain::Type{Complex}) = nothing @test _gradtest_hermsym(f, ST, A) - y = Zygote.pullback(f, A)[1] + y, back = Zygote.pullback(f, A) y2 = f(A) @test y ≈ y2 @test typeof(y) == typeof(y2) + ȳ = randn(eltype(y), size(y)) + if y isa Union{Symmetric,Hermitian} + ȳ = typeof(y)(ȳ, y.uplo) + end + Ā = back(ȳ)[1] + @test typeof(Ā) == typeof(A) + @test Ā.uplo == A.uplo @testset "similar eigenvalues" begin λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10