From fb93a4acde9bfca625a0b8dbc6fc018046124331 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Jan 2021 21:25:43 -0800 Subject: [PATCH 01/10] Remove supeseded adjoints --- src/lib/array.jl | 33 --------------------------------- 1 file changed, 33 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 733026c84..b5cd13d8c 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -673,39 +673,6 @@ function _pullback(cx::AContext, return _pullback(cx, (A, p) -> _apply_series_func(f, A, p), A, p) end -for func in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt) - @eval begin - function _pullback(cx::AContext, - f::typeof($func), - A::LinearAlgebra.RealHermSymComplexHerm) - return _pullback(cx, A -> _apply_series_func(f, A), A) - end - end -end - -@adjoint function sincos(A::LinearAlgebra.RealHermSymComplexHerm) - n = LinearAlgebra.checksquare(A) - λ, U = eigen(A) - sλ, cλ = Buffer(λ), Buffer(λ) - for i in Base.OneTo(n) - @inbounds sλ[i], cλ[i] = sincos(λ[i]) - end - sinλ, cosλ = copy(sλ), copy(cλ) - sinA, cosA = U * Diagonal(sinλ) * U', U * Diagonal(cosλ) * U' - Ω, processback = Zygote.pullback(sinA, cosA) do s,c - return (_process_series_matrix(sin, s, A, λ), - _process_series_matrix(cos, c, A, λ)) - end - return Ω, function (Ω̄) - s̄inA, c̄osA = processback(Ω̄) - s̄inΛ, c̄osΛ = U' * s̄inA * U, U' * c̄osA * U - PS = _pairdiffquotmat(sin, n, λ, sinλ, cosλ, -sinλ) - PC = _pairdiffquotmat(cos, n, λ, cosλ, -sinλ, -cosλ) - Ā = U * (PS .* s̄inΛ .+ PC .* c̄osΛ) * U' - return (Ā,) - end -end - # ChainRules has this also but does not use FillArrays, so we have our own definition # for improved performance. See https://github.com/JuliaDiff/ChainRules.jl/issues/46 Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix) From 4450893d96cc24f03c972a7aa422c0c40e07c837 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Jan 2021 21:26:00 -0800 Subject: [PATCH 02/10] Remove now-unused utility functions --- src/lib/array.jl | 20 +------------------- 1 file changed, 1 insertion(+), 19 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index b5cd13d8c..f2747af0c 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -581,10 +581,7 @@ function _realifydiag!(A) end @adjoint _realifydiag!(A) = _realifydiag!(A), Δ -> (_realifydiag!(Δ),) -_hasrealdomain(f, x) = true -_hasrealdomain(::Union{typeof.((acos,asin))...}, x) = all(x -> -1 ≤ x ≤ 1, x) -_hasrealdomain(::typeof(acosh), x) = all(x -> x ≥ 1, x) -_hasrealdomain(::Union{typeof.((log,sqrt,^))...}, x) = all(x -> x ≥ 0, x) +_hasrealdomain(::typeof(^), x) = all(x -> x ≥ 0, x) _process_series_eigvals(f, λ) = _hasrealdomain(f, λ) ? λ : complex.(λ) @@ -598,16 +595,6 @@ _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Complex}, ::AbstractVector # Compute function on eigvals, thunks for conjugates of 1st and 2nd derivatives, # and function to pull back adjoints to args -function _pullback_series_func_scalar(f, λ, args...) - compλ = _process_series_eigvals(f, λ) - fλ, fback = Zygote.pullback((x,args...) -> f.(x, args...), compλ, args...) - n = length(λ) - return (fλ, - ()->fback(ones(n))[1], - ()->nothing, # TODO: add 2nd deriv - isempty(args) ? _ -> () : f̄λ -> tail(fback(f̄λ))) -end - function _pullback_series_func_scalar(f::typeof(^), λ, p) compλ = _process_series_eigvals(f, λ) r, powλ = isinteger(p) ? (Integer(p), λ) : (p, compλ) @@ -618,11 +605,6 @@ function _pullback_series_func_scalar(f::typeof(^), λ, p) f̄λ -> (dot(fλ .* log.(compλ), f̄λ),)) end -function _pullback_series_func_scalar(f::typeof(exp), λ) - expλ = exp.(λ) - return expλ, ()->expλ, ()->expλ, _ -> () -end - _apply_series_func(f, A, args...) = f(A, args...) @adjoint function _apply_series_func(f, A, args...) From 5e05c1e7690ccb10dfaadc9a7eb32019630eda19 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Jan 2021 21:47:52 -0800 Subject: [PATCH 03/10] Custom wrap rrule for exp --- src/lib/array.jl | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/src/lib/array.jl b/src/lib/array.jl index f2747af0c..d5be8f2ed 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -570,6 +570,13 @@ end return (Ā,) end +# The adjoint for exp(::AbstractArray) intercepts ChainRules' rrule for exp(::Hermitian), +# so we call it manually. This can be removed when the generic rule for exp is moved to +# ChainRules +function _pullback(::AContext, ::typeof(exp), A::LinearAlgebra.RealHermSymComplexHerm) + return chain_rrule(exp, A) +end + # Hermitian/Symmetric matrix functions that can be written as power series _realifydiag!(A::AbstractArray{<:Real}) = A function _realifydiag!(A) From 91bedd0740eb408fbf6940b74b63ff29ff2cf3aa Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Tue, 12 Jan 2021 22:56:05 -0800 Subject: [PATCH 04/10] Check that pullback produces cotangent of same type as primal --- test/gradcheck.jl | 9 ++++++++- 1 file changed, 8 insertions(+), 1 deletion(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index d36ea68fc..a23a52aeb 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -865,10 +865,17 @@ _randmatseries(rng, ::typeof(atanh), T, n, domain::Type{Complex}) = nothing @test _gradtest_hermsym(f, ST, A) - y = Zygote.pullback(f, A)[1] + y, back = Zygote.pullback(f, A) y2 = f(A) @test y ≈ y2 @test typeof(y) == typeof(y2) + ȳ = randn(eltype(y), size(y)) + if y isa Union{Symmetric,Hermitian} + ȳ = typeof(y)(ȳ, y.uplo) + end + Ā = back(ȳ)[1] + @test typeof(Ā) == typeof(A) + @test Ā.uplo == A.uplo @testset "similar eigenvalues" begin λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10 From a0107fb7ef484c151be2626c62e2d90b3b1a0653 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 14 Jan 2021 15:13:02 -0800 Subject: [PATCH 05/10] Bump required ChainRules version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 64b7962a1..506b083b6 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "0.7.45" +ChainRules = "0.7.46" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11" ForwardDiff = "0.10" From e682ebf3bf9a05718481449e8c1e3b6d280f5770 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 14 Jan 2021 16:11:57 -0800 Subject: [PATCH 06/10] Increment required ChainRules number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 506b083b6..0c75bab01 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "0.7.46" +ChainRules = "0.7.47" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11" ForwardDiff = "0.10" From fbf6b46a79aa5376a0746bf9140f6e4ff33f969e Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 15 Jan 2021 13:24:45 +0530 Subject: [PATCH 07/10] use chain rrule for hermitian Co-authored-by: Seth Axen --- src/lib/array.jl | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index d5be8f2ed..424bc1067 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -574,7 +574,9 @@ end # so we call it manually. This can be removed when the generic rule for exp is moved to # ChainRules function _pullback(::AContext, ::typeof(exp), A::LinearAlgebra.RealHermSymComplexHerm) - return chain_rrule(exp, A) +@adoint exp(A::LinearAlgebra.RealHermSymComplexHerm) + Y, back = chain_rrule(exp, A) + return Y, Δ -> (back(Δ)[2],) end # Hermitian/Symmetric matrix functions that can be written as power series From fcdf983b0c36369638775ab2cc07983164c10431 Mon Sep 17 00:00:00 2001 From: Dhairya Gandhi Date: Fri, 15 Jan 2021 13:27:15 +0530 Subject: [PATCH 08/10] typo --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 424bc1067..5627d0f02 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -574,7 +574,7 @@ end # so we call it manually. This can be removed when the generic rule for exp is moved to # ChainRules function _pullback(::AContext, ::typeof(exp), A::LinearAlgebra.RealHermSymComplexHerm) -@adoint exp(A::LinearAlgebra.RealHermSymComplexHerm) +@adjoint exp(A::LinearAlgebra.RealHermSymComplexHerm) Y, back = chain_rrule(exp, A) return Y, Δ -> (back(Δ)[2],) end From a93c7caa2846251e2bad17a3a1ea935d027e051c Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 15 Jan 2021 00:12:22 -0800 Subject: [PATCH 09/10] Remove old signature --- src/lib/array.jl | 1 - 1 file changed, 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 5627d0f02..c7fb6140c 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -573,7 +573,6 @@ end # The adjoint for exp(::AbstractArray) intercepts ChainRules' rrule for exp(::Hermitian), # so we call it manually. This can be removed when the generic rule for exp is moved to # ChainRules -function _pullback(::AContext, ::typeof(exp), A::LinearAlgebra.RealHermSymComplexHerm) @adjoint exp(A::LinearAlgebra.RealHermSymComplexHerm) Y, back = chain_rrule(exp, A) return Y, Δ -> (back(Δ)[2],) From 044adb3275808ca7c5efb381987a10301e277dc2 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 15 Jan 2021 00:32:10 -0800 Subject: [PATCH 10/10] Add function --- src/lib/array.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index c7fb6140c..102dc09c9 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -573,7 +573,7 @@ end # The adjoint for exp(::AbstractArray) intercepts ChainRules' rrule for exp(::Hermitian), # so we call it manually. This can be removed when the generic rule for exp is moved to # ChainRules -@adjoint exp(A::LinearAlgebra.RealHermSymComplexHerm) +@adjoint function exp(A::LinearAlgebra.RealHermSymComplexHerm) Y, back = chain_rrule(exp, A) return Y, Δ -> (back(Δ)[2],) end