From 9bcbf150bfebb2e51668d27608a55849b5941fd6 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 7 Jan 2021 04:49:03 -0800 Subject: [PATCH 1/5] Remove rules for symmetric eigen and eigvals --- src/lib/array.jl | 26 -------------------------- 1 file changed, 26 deletions(-) diff --git a/src/lib/array.jl b/src/lib/array.jl index 714a70c53..733026c84 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -570,32 +570,6 @@ end return (Ā,) end -@adjoint function LinearAlgebra.eigen(A::LinearAlgebra.RealHermSymComplexHerm) - dU = eigen(A) - return dU, function (Δ) - d, U = dU - d̄, Ū = Δ - if Ū === nothing - P = Diagonal(d̄) - else - F = inv.(d' .- d) - P = F .* (U' * Ū) - if d̄ === nothing - P[diagind(P)] .= 0 - else - P[diagind(P)] = d̄ - end - end - return (U * P * U',) - end -end - -@adjoint function LinearAlgebra.eigvals(A::LinearAlgebra.RealHermSymComplexHerm) - d, U = eigen(A) - return d, d̄ -> (U * Diagonal(d̄) * U',) -end - - # Hermitian/Symmetric matrix functions that can be written as power series _realifydiag!(A::AbstractArray{<:Real}) = A function _realifydiag!(A) From 5c8e8429efbfa2473cf0d81e0cf16fd34040a633 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 7 Jan 2021 04:50:30 -0800 Subject: [PATCH 2/5] Increment required ChainRules version --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 461bf3710..53e24ae26 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "0.7.34" +ChainRules = "0.7.44" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11" ForwardDiff = "0.10" From a3c3665b5cd3d6ecc170a62eb7305e0e27b856b8 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 7 Jan 2021 16:55:46 -0800 Subject: [PATCH 3/5] Increment required version number --- Project.toml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/Project.toml b/Project.toml index 53e24ae26..64b7962a1 100644 --- a/Project.toml +++ b/Project.toml @@ -22,7 +22,7 @@ ZygoteRules = "700de1a5-db45-46bc-99cf-38207098b444" [compat] AbstractFFTs = "0.5, 1.0" -ChainRules = "0.7.44" +ChainRules = "0.7.45" DiffRules = "1.0" FillArrays = "0.8, 0.9, 0.10, 0.11" ForwardDiff = "0.10" From fcf555907199675ca6a761b16674da611d547de0 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Thu, 7 Jan 2021 17:57:37 -0800 Subject: [PATCH 4/5] Remove eigen/eigvals tests --- test/gradcheck.jl | 43 ------------------------------------------- 1 file changed, 43 deletions(-) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index d36ea68fc..f81f2aba6 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -775,49 +775,6 @@ function _gradtest_hermsym(f, ST, A) end end -@testset "eigen(::RealHermSymComplexHerm)" begin - MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64}) - rng, N = MersenneTwister(123), 7 - @testset "eigen(::$MT)" for MT in MTs - T = eltype(MT) - ST = _hermsymtype(MT) - - A = ST(randn(rng, T, N, N)) - U = eigvecs(A) - - @test _gradtest_hermsym(ST, A) do (A) - d, U = eigen(A) - return U * Diagonal(exp.(d)) * U' - end - - y = Zygote.pullback(eigen, A)[1] - y2 = eigen(A) - @test y.values ≈ y2.values - @test y.vectors ≈ y2.vectors - - @testset "low rank" begin - A2 = Symmetric(U * Diagonal([randn(rng), zeros(N-1)...]) * U') - @test_broken _gradtest_hermsym(ST, A2) do (A) - d, U = eigen(A) - return U * Diagonal(exp.(d)) * U' - end - end - end -end - -@testset "eigvals(::RealHermSymComplexHerm)" begin - MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64}) - rng, N = MersenneTwister(123), 7 - @testset "eigvals(::$MT)" for MT in MTs - T = eltype(MT) - ST = _hermsymtype(MT) - - A = ST(randn(rng, T, N, N)) - @test _gradtest_hermsym(A ->eigvals(A), ST, A) - @test Zygote.pullback(eigvals, A)[1] ≈ eigvals(A) - end -end - _randmatunitary(rng, T, n) = qr(randn(rng, T, n, n)).Q function _randvectorin(rng, n, r) l, u = r From d747b639a9eaa469d819414b124d3f87f5f61e81 Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Sat, 9 Jan 2021 12:28:15 -0800 Subject: [PATCH 5/5] Revert "Remove eigen/eigvals tests" This reverts commit fcf555907199675ca6a761b16674da611d547de0. --- test/gradcheck.jl | 43 +++++++++++++++++++++++++++++++++++++++++++ 1 file changed, 43 insertions(+) diff --git a/test/gradcheck.jl b/test/gradcheck.jl index f81f2aba6..d36ea68fc 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -775,6 +775,49 @@ function _gradtest_hermsym(f, ST, A) end end +@testset "eigen(::RealHermSymComplexHerm)" begin + MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64}) + rng, N = MersenneTwister(123), 7 + @testset "eigen(::$MT)" for MT in MTs + T = eltype(MT) + ST = _hermsymtype(MT) + + A = ST(randn(rng, T, N, N)) + U = eigvecs(A) + + @test _gradtest_hermsym(ST, A) do (A) + d, U = eigen(A) + return U * Diagonal(exp.(d)) * U' + end + + y = Zygote.pullback(eigen, A)[1] + y2 = eigen(A) + @test y.values ≈ y2.values + @test y.vectors ≈ y2.vectors + + @testset "low rank" begin + A2 = Symmetric(U * Diagonal([randn(rng), zeros(N-1)...]) * U') + @test_broken _gradtest_hermsym(ST, A2) do (A) + d, U = eigen(A) + return U * Diagonal(exp.(d)) * U' + end + end + end +end + +@testset "eigvals(::RealHermSymComplexHerm)" begin + MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64}) + rng, N = MersenneTwister(123), 7 + @testset "eigvals(::$MT)" for MT in MTs + T = eltype(MT) + ST = _hermsymtype(MT) + + A = ST(randn(rng, T, N, N)) + @test _gradtest_hermsym(A ->eigvals(A), ST, A) + @test Zygote.pullback(eigvals, A)[1] ≈ eigvals(A) + end +end + _randmatunitary(rng, T, n) = qr(randn(rng, T, n, n)).Q function _randvectorin(rng, n, r) l, u = r