diff --git a/src/lib/array.jl b/src/lib/array.jl index 7cc1101c1..378110af2 100644 --- a/src/lib/array.jl +++ b/src/lib/array.jl @@ -416,6 +416,11 @@ end return H, back end +@adjoint convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array} = convert(R, A), + Δ -> (nothing, convert(S, Δ),) +@adjoint Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S} = Matrix(A), + Δ -> (convert(S, Δ),) + @adjoint function cholesky(Σ::Real) C = cholesky(Σ) return C, Δ::NamedTuple->(Δ.factors[1, 1] / (2 * C.U[1, 1]),) @@ -451,6 +456,25 @@ end end end +# Matrix of pairwise difference quotients +Base.@propagate_inbounds function _pairdiffquot(f, i, j, x, fx, dfx, d²fx = nothing) + i == j && return dfx[i] + Δx = x[i] - x[j] + T = real(eltype(x)) + if d²fx === nothing + abs(Δx) ≤ sqrt(eps(T)) && return (dfx[i] + dfx[j]) / 2 + else + abs(Δx) ≤ eps(T)^(1/3) && return dfx[i] - Δx / 2 * d²fx[i] + end + Δfx = fx[i] - fx[j] + return Δfx / Δx +end + +Base.@propagate_inbounds function _pairdiffquotmat(f, n, x, fx, dfx, d²fx = nothing) + Δfij = (i, j)->_pairdiffquot(f, i, j, x, fx, dfx, d²fx) + return Δfij.(Base.OneTo(n), Base.OneTo(n)') +end + # Adjoint based on the Theano implementation, which uses the differential as described # in Brančík, "Matlab programs for matrix exponential function derivative evaluation" @adjoint exp(A::AbstractMatrix) = exp(A), function(F̄) @@ -458,12 +482,13 @@ end E = eigen(A) w = E.values ew = exp.(w) - X = [i==j ? ew[i] : (ew[i]-ew[j])/(w[i]-w[j]) for i in 1:n,j=1:n] + X = _pairdiffquotmat(exp, n, w, ew, ew, ew) V = E.vectors VF = factorize(V) Ā = (V * ((VF \ F̄' * V) .* X) / VF)' return (Ā,) end + @adjoint function LinearAlgebra.eigen(A::LinearAlgebra.RealHermSymComplexHerm) dU = eigen(A) return dU, function (Δ) @@ -489,6 +514,143 @@ end return d, d̄ -> (U * Diagonal(d̄) * U',) end + +# Hermitian/Symmetric matrix functions that can be written as power series +_realifydiag!(A::AbstractArray{<:Real}) = A +function _realifydiag!(A) + n = LinearAlgebra.checksquare(A) + for i in 1:n + @inbounds A[i,i] = real(A[i,i]) + end + return A +end +@adjoint _realifydiag!(A) = _realifydiag!(A), Δ -> (_realifydiag!(Δ),) + +_hasrealdomain(f, x) = true +_hasrealdomain(::Union{typeof.((acos,asin))...}, x) = all(x -> -1 ≤ x ≤ 1, x) +_hasrealdomain(::typeof(acosh), x) = all(x -> x ≥ 1, x) +_hasrealdomain(::Union{typeof.((log,sqrt,^))...}, x) = all(x -> x ≥ 0, x) + +_process_series_eigvals(f, λ) = _hasrealdomain(f, λ) ? λ : complex.(λ) + +_process_series_matrix(f, fA, A, fλ) = fA +_process_series_matrix(f, fA, ::LinearAlgebra.HermOrSym{<:Real}, fλ) = Symmetric(fA) +_process_series_matrix(f, fA, ::Hermitian{<:Complex}, ::AbstractVector{<:Real}) = + Hermitian(_realifydiag!(fA)) +_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real}, fλ) = Hermitian(fA) +_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real}, ::AbstractVector{<:Complex}) = fA +_process_series_matrix(::typeof(^), fA, ::Hermitian{<:Complex}, ::AbstractVector{<:Complex}) = fA + +# Compute function on eigvals, thunks for conjugates of 1st and 2nd derivatives, +# and function to pull back adjoints to args +function _pullback_series_func_scalar(f, λ, args...) + compλ = _process_series_eigvals(f, λ) + fλ, fback = Zygote.pullback((x,args...) -> f.(x, args...), compλ, args...) + n = length(λ) + return (fλ, + ()->fback(ones(n))[1], + ()->nothing, # TODO: add 2nd deriv + isempty(args) ? _ -> () : f̄λ -> tail(fback(f̄λ))) +end + +function _pullback_series_func_scalar(f::typeof(^), λ, p) + compλ = _process_series_eigvals(f, λ) + r, powλ = isinteger(p) ? (Integer(p), λ) : (p, compλ) + fλ = powλ .^ r + return (fλ, + ()->conj.(r .* powλ .^ (r - 1)), + ()->conj.((r * (r - 1)) .* powλ .^ (r - 2)), + f̄λ -> (dot(fλ .* log.(compλ), f̄λ),)) +end + +function _pullback_series_func_scalar(f::typeof(exp), λ) + expλ = exp.(λ) + return expλ, ()->expλ, ()->expλ, _ -> () +end + +_apply_series_func(f, A, args...) = f(A, args...) + +@adjoint function _apply_series_func(f, A, args...) + hasargs = !isempty(args) + n = LinearAlgebra.checksquare(A) + λ, U = eigen(A) + fλ, dfthunk, d²fthunk, argsback = _pullback_series_func_scalar(f, λ, args...) + fΛ = Diagonal(fλ) + fA = U * fΛ * U' + Ω = _process_series_matrix(f, fA, A, fλ) + return Ω, function (f̄A) + f̄Λ = U' * f̄A * U + ārgs = hasargs ? argsback(diag(f̄Λ)) : () + P = _pairdiffquotmat(f, n, λ, conj(fλ), dfthunk(), d²fthunk()) + Ā = U * (P .* f̄Λ) * U' + return (nothing, Ā, ārgs...) + end +end + +_hermsympow(A::Symmetric, p::Integer) = LinearAlgebra.sympow(A, p) +_hermsympow(A::Hermitian, p::Integer) = A^p + +@adjoint function _hermsympow(A::Hermitian, p::Integer) + if p < 0 + B, back = Zygote.pullback(A->Base.power_by_squaring(inv(A), -p), A) + else + B, back = Zygote.pullback(A->Base.power_by_squaring(A, p), A) + end + Ω = Hermitian(_realifydiag!(B)) + return Ω, function (Ω̄) + B̄ = _hermitian_back(Ω̄, 'U') + Ā = back(B̄)[1] + return (Ā, nothing) + end +end + +_pullback(cx::AContext, ::typeof(^), A::LinearAlgebra.HermOrSym{<:Real}, p::Integer) = + _pullback(cx, _hermsympow, A, p) +_pullback(cx::AContext, ::typeof(^), A::Symmetric{<:Complex}, p::Integer) = + _pullback(cx, _hermsympow, A, p) +_pullback(cx::AContext, ::typeof(^), A::Hermitian{<:Complex}, p::Integer) = + _pullback(cx, _hermsympow, A, p) + +function _pullback(cx::AContext, + f::typeof(^), + A::LinearAlgebra.RealHermSymComplexHerm, + p::Real) + return _pullback(cx, (A, p) -> _apply_series_func(f, A, p), A, p) +end + +for func in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt) + @eval begin + function _pullback(cx::AContext, + f::typeof($func), + A::LinearAlgebra.RealHermSymComplexHerm) + return _pullback(cx, A -> _apply_series_func(f, A), A) + end + end +end + +@adjoint function sincos(A::LinearAlgebra.RealHermSymComplexHerm) + n = LinearAlgebra.checksquare(A) + λ, U = eigen(A) + sλ, cλ = Buffer(λ), Buffer(λ) + for i in Base.OneTo(n) + @inbounds sλ[i], cλ[i] = sincos(λ[i]) + end + sinλ, cosλ = copy(sλ), copy(cλ) + sinA, cosA = U * Diagonal(sinλ) * U', U * Diagonal(cosλ) * U' + Ω, processback = Zygote.pullback(sinA, cosA) do s,c + return (_process_series_matrix(sin, s, A, λ), + _process_series_matrix(cos, c, A, λ)) + end + return Ω, function (Ω̄) + s̄inA, c̄osA = processback(Ω̄) + s̄inΛ, c̄osΛ = U' * s̄inA * U, U' * c̄osA * U + PS = _pairdiffquotmat(sin, n, λ, sinλ, cosλ, -sinλ) + PC = _pairdiffquotmat(cos, n, λ, cosλ, -sinλ, -cosλ) + Ā = U * (PS .* s̄inΛ .+ PC .* c̄osΛ) * U' + return (Ā,) + end +end + Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix) # x is a squre matrix checked by tr, # so we could just use Eye(size(x, 1)) diff --git a/src/lib/number.jl b/src/lib/number.jl index 5a0dcab87..8e293b49a 100644 --- a/src/lib/number.jl +++ b/src/lib/number.jl @@ -49,6 +49,9 @@ end (s, c), ((s̄, c̄),) -> (s̄*c - c̄*s,) end +@adjoint acosh(x::Complex) = + acosh(x), Δ -> (Δ * conj(inv(sqrt(x - 1) * sqrt(x + 1))),) + @adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, - c̄ * a // b // b)) @nograd floor, ceil, trunc, round, hash diff --git a/test/gradcheck.jl b/test/gradcheck.jl index 499e3578e..a069ff207 100644 --- a/test/gradcheck.jl +++ b/test/gradcheck.jl @@ -25,10 +25,25 @@ gradcheck(f, xs...) = gradtest(f, xs::AbstractArray...) = gradcheck((xs...) -> sum(sin.(f(xs...))), xs...) gradtest(f, dims...) = gradtest(f, rand.(Float64, dims)...) +# utilities for using gradcheck with complex matrices +_splitreim(A) = (real(A),) +_splitreim(A::AbstractArray{<:Complex}) = reim(A) + +_joinreim(A, B) = complex.(A, B) +_joinreim(A) = A + +function _dropimaggrad(A) + back(Δ) = real(Δ) + back(Δ::Nothing) = nothing + return Zygote.hook(back, A) +end + Random.seed!(0) @test gradient(//, 2, 3) === (1//3, -2//9) +@test gradtest((a,b)->sum(reim(acosh(complex(a[1], b[1])))), [-2.0], [1.0]) + @test gradtest((x, W, b) -> identity.(W*x .+ b), 5, (2,5), 2) @test gradtest((x, W, b) -> identity.(W*x .+ b), (5,3), (2,5), 2) @@ -519,113 +534,266 @@ end for i = 1:5 A = randn(rng, N, N) @test gradtest(exp, A) + + @testset "similar eigenvalues" begin + λ, V = eigen(A) + λ[1] = λ[3] + sqrt(eps(real(eltype(λ)))) / 10 + A2 = real.(V * Diagonal(λ) / V) + @test gradtest(exp, A2) + end end end @testset "complex dense" begin rng, N = MersenneTwister(6865931), 8 for i = 1:5 - A = randn(rng, N, N) - B = randn(rng, N, N) - @test gradcheck(A, B) do a,b + A = randn(rng, ComplexF64, N, N) + @test gradcheck(reim(A)...) do a,b c = complex.(a, b) d = exp(c) return sum(real.(d) + 2 .* imag.(d)) end + + @testset "similar eigenvalues" begin + λ, V = eigen(A) + λ[1] = λ[3] + sqrt(eps(real(eltype(λ)))) / 10 + A2 = V * Diagonal(λ) / V + @test gradcheck(reim(A2)...) do a,b + c = complex.(a, b) + d = exp(c) + return sum(real.(d) + 2 .* imag.(d)) + end + end end end end +_hermsymtype(::Type{<:Symmetric}) = Symmetric +_hermsymtype(::Type{<:Hermitian}) = Hermitian + +function _gradtest_hermsym(f, ST, A) + gradtest(_splitreim(collect(A))...) do (args...) + B = f(ST(_joinreim(_dropimaggrad.(args)...))) + return sum(_splitreim(B)) + end +end + @testset "eigen(::RealHermSymComplexHerm)" begin - @testset "eigen(::Symmetric{<:Real})" begin - rng, N = MersenneTwister(123), 7 - A = Symmetric(randn(rng, N, N)) - @test gradtest(collect(A)) do (x) - d, Q = eigen(Symmetric(x)) - return Q * Diagonal(exp.(d)) * transpose(Q) + MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64}) + rng, N = MersenneTwister(123), 7 + @testset "eigen(::$MT)" for MT in MTs + T = eltype(MT) + ST = _hermsymtype(MT) + + A = ST(randn(rng, T, N, N)) + U = eigvecs(A) + + @test _gradtest_hermsym(ST, A) do (A) + d, U = eigen(A) + return U * Diagonal(exp.(d)) * U' end + y = Zygote.pullback(eigen, A)[1] y2 = eigen(A) @test y.values ≈ y2.values @test y.vectors ≈ y2.vectors + @testset "low rank" begin - U = eigvecs(A) A2 = Symmetric(U * Diagonal([randn(rng), zeros(N-1)...]) * U') - @test_broken gradtest(collect(A2)) do (x) - d, Q = eigen(Symmetric(x)) - return Q * Diagonal(exp.(d)) * transpose(Q) + @test_broken _gradtest_hermsym(ST, A2) do (A) + d, U = eigen(A) + return U * Diagonal(exp.(d)) * U' end end end +end - @testset "eigen(::Hermitian{<:Real})" begin - rng, N = MersenneTwister(456), 7 - A = Hermitian(randn(rng, N, N)) - @test gradtest(collect(A)) do (x) - d, Q = eigen(Hermitian(x)) - return Q * Diagonal(exp.(d)) * transpose(Q) - end - y = Zygote.pullback(eigen, A)[1] - y2 = eigen(A) - @test y.values ≈ y2.values - @test y.vectors ≈ y2.vectors - @testset "low rank" begin - U = eigvecs(A) - A2 = Hermitian(U * Diagonal([randn(rng), zeros(N-1)...]) * U') - @test_broken gradtest(collect(A2)) do (x) - d, Q = eigen(Hermitian(x)) - return Q * Diagonal(exp.(d)) * transpose(Q) +@testset "eigvals(::RealHermSymComplexHerm)" begin + MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64}) + rng, N = MersenneTwister(123), 7 + @testset "eigvals(::$MT)" for MT in MTs + T = eltype(MT) + ST = _hermsymtype(MT) + + A = ST(randn(rng, T, N, N)) + @test _gradtest_hermsym(A ->eigvals(A), ST, A) + @test Zygote.pullback(eigvals, A)[1] ≈ eigvals(A) + end +end + +_randmatunitary(rng, T, n) = qr(randn(rng, T, n, n)).Q +function _randvectorin(rng, n, r) + l, u = r + isinf(l) && isinf(u) && return randn(rng, n) + isinf(l) && return rand(rng, n) .+ (u - 1) + isinf(u) && return rand(rng, n) .+ l + return rand(rng, n) .* (u - l) .+ l +end + +realdomainrange(::Any) = (Inf, Inf) +realdomainrange(::Union{typeof.((acos,asin,atanh))...}) = (-1, 1) +realdomainrange(::typeof(acosh)) = (1, Inf) +realdomainrange(::Union{typeof.((log,sqrt,^))...}) = (0, Inf) + +function _randmatseries(rng, f, T, n, domain::Type{Real}) + U = _randmatunitary(rng, T, n) + λ = _randvectorin(rng, n, realdomainrange(f)) + return U * Diagonal(λ) * U' +end + +function _randmatseries(rng, f, T, n, domain::Type{Complex}) + U = _randmatunitary(rng, T, n) + r = realdomainrange(f) + r == (Inf, Inf) && return nothing + λ = _randvectorin(rng, n, r) + λ[end] -= 2 + return U * Diagonal(λ) * U' +end + +_randmatseries(rng, ::typeof(atanh), T, n, domain::Type{Complex}) = nothing + +@testset "Hermitian/Symmetric power series functions" begin + MTs = (Symmetric{Float64}, Hermitian{Float64}, Hermitian{ComplexF64}) + rng, N = MersenneTwister(123), 7 + domains = (Real, Complex) + @testset "$func(::RealHermSymComplexHerm)" for func in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh, :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt) + f = eval(func) + @testset "$func(::$MT)" for MT in MTs + T = eltype(MT) + ST = _hermsymtype(MT) + @testset "domain $domain" for domain in domains + preA = _randmatseries(rng, f, T, N, domain) + preA === nothing && continue + A = ST(preA) + λ, U = eigen(A) + + @test _gradtest_hermsym(f, ST, A) + + y = Zygote.pullback(f, A)[1] + y2 = f(A) + @test y ≈ y2 + @test typeof(y) == typeof(y2) + + @testset "similar eigenvalues" begin + λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10 + A2 = U * Diagonal(λ) * U' + @test _gradtest_hermsym(f, ST, A2) + end + + if f ∉ (log, sqrt) # only defined for invertible matrices + @testset "low rank" begin + A3 = U * Diagonal([rand(rng), zeros(N-1)...]) * U' + @test _gradtest_hermsym(f, ST, A3) + end + end end end end - @testset "eigen(::Hermitian{<:Complex})" begin - rng, N = MersenneTwister(789), 7 - A = Hermitian(randn(rng, ComplexF64, N, N)) - @test gradtest(reim(collect(A))...) do a,b - d, U = eigen(Hermitian(complex.(a, b))) - X = U * Diagonal(exp.(d)) * U' - return real.(X) .+ imag.(X) - end - y = Zygote.pullback(eigen, A)[1] - y2 = eigen(A) - @test y.values ≈ y2.values - @test y.vectors ≈ y2.vectors - @testset "low rank" begin - U = eigvecs(A) - A2 = Hermitian(U * Diagonal([randn(rng), zeros(N-1)...]) * U') - @test_broken gradtest(reim(collect(A2))...) do a,b - d, U = eigen(Hermitian(complex.(a, b))) - X = U * Diagonal(exp.(d)) * U' - return real.(X) .+ imag.(X) + @testset "sincos(::RealHermSymComplexHerm)" begin + @testset "sincos(::$MT)" for MT in MTs + T = eltype(MT) + ST = _hermsymtype(MT) + A = ST(_randmatseries(rng, sincos, T, N, Real)) + λ, U = eigen(A) + + @test gradtest(_splitreim(collect(A))...) do (args...) + S,C = sincos(ST(_joinreim(_dropimaggrad.(args)...))) + return vcat(vec.(_splitreim(S))..., vec.(_splitreim(C))...) + end + + y = Zygote.pullback(sincos, A)[1] + y2 = sincos(A) + @test y[1] ≈ y2[1] + @test typeof(y[1]) == typeof(y2[1]) + @test y[2] ≈ y2[2] + @test typeof(y[2]) == typeof(y2[2]) + + @testset "similar eigenvalues" begin + λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10 + A2 = U * Diagonal(λ) * U' + @test gradtest(_splitreim(collect(A2))...) do (args...) + S,C = sincos(ST(_joinreim(_dropimaggrad.(args)...))) + return vcat(vec.(_splitreim(S))..., vec.(_splitreim(C))...) + end + end + + @testset "low rank" begin + A3 = U * Diagonal([rand(rng), zeros(N-1)...]) * U' + @test gradtest(_splitreim(collect(A3))...) do (args...) + S,C = sincos(ST(_joinreim(_dropimaggrad.(args)...))) + return vcat(vec.(_splitreim(S))..., vec.(_splitreim(C))...) + end end end end -end -@testset "eigvals(::RealHermSymComplexHerm)" begin - @testset "eigvals(::Symmetric{<:Real})" begin - rng, N = MersenneTwister(123), 7 - A = Symmetric(randn(rng, N, N)) - @test gradtest(x->eigvals(Symmetric(x)), collect(A)) - @test Zygote.pullback(eigvals, A)[1] ≈ eigvals(A) + @testset "^(::RealHermSymComplexHerm, p::Real)" begin + @testset for p in (-1.0, -0.5, 0.5, 1.0, 1.5) + @testset "^(::$MT, $p)" for MT in MTs + T = eltype(MT) + ST = _hermsymtype(MT) + @testset "domain $domain" for domain in domains + A = ST(_randmatseries(rng, ^, T, N, domain)) + λ, U = eigen(A) + + @test gradcheck(_splitreim(collect(A))..., [p]) do (args...) + p = _dropimaggrad(args[end][1]) + A = ST(_joinreim(_dropimaggrad.(args[1:end-1])...)) + B = A^p + return abs(sum(sin.(B))) + end + + y = Zygote.pullback(^, A, p)[1] + y2 = A^p + @test y ≈ y2 + @test typeof(y) == typeof(y2) + + @testset "similar eigenvalues" begin + λ[1] = λ[3] + sqrt(eps(eltype(λ))) / 10 + A2 = U * Diagonal(λ) * U' + @test gradcheck(_splitreim(collect(A2))..., [p]) do (args...) + p = _dropimaggrad(args[end][1]) + A = ST(_joinreim(_dropimaggrad.(args[1:end-1])...)) + B = A^p + return abs(sum(sin.(B))) + end + end + end + end + end end +end - @testset "eigvals(::Hermitian{<:Real})" begin - rng, N = MersenneTwister(456), 7 - A = Hermitian(randn(rng, N, N)) - @test gradtest(x->eigvals(Hermitian(x)), collect(A)) - @test Zygote.pullback(eigvals, A)[1] ≈ eigvals(A) - end +@testset "^(::Union{Symmetric,Hermitian}, p::Integer)" begin + MTs = (Symmetric{Float64}, Symmetric{ComplexF64}, + Hermitian{Float64}, Hermitian{ComplexF64}) + rng, N = MersenneTwister(123), 7 + @testset for p in -3:3 + @testset "^(::$MT, $p)" for MT in MTs + T = eltype(MT) + ST = _hermsymtype(MT) + A = ST(randn(rng, T, N, N)) + + if p == 0 + @test gradient(_splitreim(collect(A))...) do (args...) + A = ST(_joinreim(_dropimaggrad.(args)...)) + B = A^p + return sum(sin.(vcat(vec.(_splitreim(B))...))) + end === map(_->nothing, _splitreim(A)) + else + @test gradtest(_splitreim(collect(A))...) do (args...) + A = ST(_joinreim(_dropimaggrad.(args)...)) + B = A^p + return vcat(vec.(_splitreim(B))...) + end + end - @testset "eigvals(::Hermitian{<:Complex})" begin - rng, N = MersenneTwister(789), 7 - A, B = randn(rng, N, N), randn(rng, N, N) - @test gradtest(A, B) do a,b - c = Hermitian(complex.(a, b)) - return eigvals(c) + y = Zygote.pullback(^, A, p)[1] + y2 = A^p + @test y ≈ y2 + @test typeof(y) === typeof(y2) end - @test Zygote.pullback(eigvals, Hermitian(A))[1] ≈ eigvals(Hermitian(A)) end end