Skip to content

Commit

Permalink
Try #355:
Browse files Browse the repository at this point in the history
  • Loading branch information
bors[bot] authored Nov 9, 2019
2 parents 6ac5b5b + fb95ec7 commit 287f704
Show file tree
Hide file tree
Showing 3 changed files with 403 additions and 70 deletions.
164 changes: 163 additions & 1 deletion src/lib/array.jl
Original file line number Diff line number Diff line change
Expand Up @@ -416,6 +416,11 @@ end
return H, back
end

# Adjoint of `convert` from a Hermitian/Symmetric wrapper to a plain `Array` type.
# The `Type` argument receives no gradient; the incoming cotangent is converted to
# `S`, the second type parameter of the wrapper.
@adjoint function convert(::Type{R}, A::LinearAlgebra.HermOrSym{T,S}) where {T,S,R<:Array}
    dense = convert(R, A)
    return dense, Δ -> (nothing, convert(S, Δ))
end
# Adjoint of the `Matrix` constructor applied to a Hermitian/Symmetric wrapper.
# The cotangent is converted to `S`, the second type parameter of the wrapper.
@adjoint function Matrix(A::LinearAlgebra.HermOrSym{T,S}) where {T,S}
    dense = Matrix(A)
    return dense, Δ -> (convert(S, Δ),)
end

@adjoint function cholesky::Real)
C = cholesky(Σ)
return C, Δ::NamedTuple->.factors[1, 1] / (2 * C.U[1, 1]),)
Expand Down Expand Up @@ -451,19 +456,39 @@ end
end
end

# Matrix of pairwise difference quotients
# Entry `(i, j)` of the matrix of pairwise difference quotients.
#
# Arguments:
#   f     - the function being differenced (unused in the body).
#   i, j  - indices into `x`, `fx`, `dfx` (and `d²fx` when given).
#   x     - sample points.
#   fx    - function values at `x`.
#   dfx   - first-derivative values at `x`.
#   d²fx  - optional second-derivative values at `x`; `nothing` when unavailable.
#
# Returns `dfx[i]` on the diagonal and `(fx[i] - fx[j]) / (x[i] - x[j])` off the
# diagonal. When the two points are nearly equal, the quotient is replaced by an
# estimate built from the derivatives so that the result stays finite.
Base.@propagate_inbounds function _pairdiffquot(f, i, j, x, fx, dfx, d²fx = nothing)
    i == j && return dfx[i]
    Δx = x[i] - x[j]
    T = real(eltype(x))
    if d²fx === nothing
        # Without a second derivative, average the two first derivatives.
        abs(Δx) <= sqrt(eps(T)) && return (dfx[i] + dfx[j]) / 2
    else
        # With a second derivative, correct `dfx[i]` by a term linear in `Δx`.
        abs(Δx) <= eps(T)^(1/3) && return dfx[i] - Δx / 2 * d²fx[i]
    end
    Δfx = fx[i] - fx[j]
    return Δfx / Δx
end

# Build the `n`-by-`n` matrix whose `(i, j)` entry is
# `_pairdiffquot(f, i, j, x, fx, dfx, d²fx)`, by broadcasting over a column of row
# indices and a row of column indices.
Base.@propagate_inbounds function _pairdiffquotmat(f, n, x, fx, dfx, d²fx = nothing)
    idx = Base.OneTo(n)
    entry = (row, col) -> _pairdiffquot(f, row, col, x, fx, dfx, d²fx)
    return broadcast(entry, idx, idx')
end

# Adjoint based on the Theano implementation, which uses the differential as described
# in Brančík, "Matlab programs for matrix exponential function derivative evaluation"
# Adjoint of the matrix exponential. The pullback eigendecomposes `A`, forms the
# matrix `X` of pairwise difference quotients of `exp` over the eigenvalues, and
# maps the cotangent `F̄` through the eigenvector basis.
@adjoint exp(A::AbstractMatrix) = exp(A), function(F̄)
    n = size(A, 1)
    E = eigen(A)
    w = E.values
    ew = exp.(w)
    # `exp` is its own first and second derivative, hence `ew` three times.
    # `_pairdiffquotmat` stays finite for repeated eigenvalues, which a naive
    # `(ew[i] - ew[j]) / (w[i] - w[j])` quotient does not.
    X = _pairdiffquotmat(exp, n, w, ew, ew, ew)
    V = E.vectors
    VF = factorize(V)
    Ā = (V * ((VF \ F̄' * V) .* X) / VF)'
    return (Ā,)
end

@adjoint function LinearAlgebra.eigen(A::LinearAlgebra.RealHermSymComplexHerm)
dU = eigen(A)
return dU, function (Δ)
Expand All @@ -489,6 +514,143 @@ end
return d, d̄ -> (U * Diagonal(d̄) * U',)
end


# Hermitian/Symmetric matrix functions that can be written as power series
# Overwrite every diagonal entry of the square matrix `A` with its real part and
# return `A`. Arrays with real element type are returned untouched.
# `LinearAlgebra.checksquare` rejects non-square input.
_realifydiag!(A::AbstractArray{<:Real}) = A
function _realifydiag!(A)
    dim = LinearAlgebra.checksquare(A)
    @inbounds for k in Base.OneTo(dim)
        A[k, k] = real(A[k, k])
    end
    return A
end
# Adjoint of `_realifydiag!`: the pullback applies the same diagonal
# realification to the incoming cotangent (in place).
@adjoint function _realifydiag!(A)
    out = _realifydiag!(A)
    return out, Δ -> (_realifydiag!(Δ),)
end

# `_hasrealdomain(f, x)` is `true` when every element of `x` lies in the interval
# checked for `f`:
#   acos, asin   -> [-1, 1]
#   acosh        -> [1, ∞)
#   log, sqrt, ^ -> [0, ∞)
# Every other function is treated as real-valued everywhere.
_hasrealdomain(f, x) = true
_hasrealdomain(::Union{typeof.((acos,asin))...}, x) = all(v -> -1 <= v <= 1, x)
_hasrealdomain(::typeof(acosh), x) = all(v -> v >= 1, x)
_hasrealdomain(::Union{typeof.((log,sqrt,^))...}, x) = all(v -> v >= 0, x)

# Return the eigenvalues `λ` unchanged when `_hasrealdomain(f, λ)` holds;
# otherwise return them promoted elementwise with `complex`.
function _process_series_eigvals(f, λ)
    _hasrealdomain(f, λ) && return λ
    return complex.(λ)
end

# Choose the wrapper for the computed matrix `fA`, given the function `f`, the
# input matrix `A` and the transformed eigenvalues `fλ`.

# Fallback: hand back `fA` as is.
_process_series_matrix(f, fA, A, fλ) = fA

# Hermitian/Symmetric input with real elements: wrap as `Symmetric`.
function _process_series_matrix(f, fA, ::LinearAlgebra.HermOrSym{<:Real}, fλ)
    return Symmetric(fA)
end

# Complex Hermitian input with real transformed eigenvalues: realify the
# diagonal, then wrap as `Hermitian`.
function _process_series_matrix(f, fA, ::Hermitian{<:Complex}, ::AbstractVector{<:Real})
    return Hermitian(_realifydiag!(fA))
end

# Powers: a real `Hermitian` input gives a `Hermitian` result, unless the
# transformed eigenvalues are complex, in which case `fA` is returned unwrapped.
function _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real}, fλ)
    return Hermitian(fA)
end
function _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Real},
                                ::AbstractVector{<:Complex})
    return fA
end
function _process_series_matrix(::typeof(^), fA, ::Hermitian{<:Complex},
                                ::AbstractVector{<:Complex})
    return fA
end

# Compute function on eigvals, thunks for conjugates of 1st and 2nd derivatives,
# and function to pull back adjoints to args
# Generic case. Broadcasts `f` over the (possibly complexified) eigenvalues through
# `Zygote.pullback` and returns a 4-tuple:
#   1. the function values `fλ`,
#   2. a thunk giving the first pullback component for an all-ones cotangent,
#   3. a thunk returning `nothing` (second derivative not implemented),
#   4. a function mapping a cotangent of `fλ` to the cotangents of `args`
#      (an empty tuple when there are no extra arguments).
function _pullback_series_func_scalar(f, λ, args...)
    domainλ = _process_series_eigvals(f, λ)
    broadcastf = (x, args...) -> f.(x, args...)
    fλ, fback = Zygote.pullback(broadcastf, domainλ, args...)
    n = length(λ)
    dfthunk = () -> fback(ones(n))[1]
    d²fthunk = () -> nothing # TODO: add 2nd deriv
    argsback = isempty(args) ? (_ -> ()) : (f̄λ -> tail(fback(f̄λ)))
    return (fλ, dfthunk, d²fthunk, argsback)
end

# Specialization for powers `λ .^ p`, with closed-form derivatives.
# Returns a 4-tuple:
#   1. the powers `fλ`,
#   2. a thunk for the conjugated first derivative `r * λ^(r - 1)`,
#   3. a thunk for the conjugated second derivative `r * (r - 1) * λ^(r - 2)`,
#   4. a function mapping a cotangent of `fλ` to the 1-tuple cotangent of `p`.
function _pullback_series_func_scalar(f::typeof(^), λ, p)
    compλ = _process_series_eigvals(f, λ)
    # Integer-valued exponents are applied to the original eigenvalues as an
    # `Integer`; any other exponent uses the possibly complexified eigenvalues.
    r, powλ = isinteger(p) ? (Integer(p), λ) : (p, compλ)
    fλ = powλ .^ r
    return (fλ,
            ()->conj.(r .* powλ .^ (r - 1)),
            ()->conj.((r * (r - 1)) .* powλ .^ (r - 2)),
            f̄λ -> (dot(fλ .* log.(compλ), f̄λ),))
end

# Specialization for `exp`: the function values double as the first and second
# derivatives, and there are no extra arguments to pull back to.
function _pullback_series_func_scalar(f::typeof(exp), λ)
    eλ = exp.(λ)
    dfthunk = () -> eλ
    d²fthunk = () -> eλ
    noargsback = _ -> ()
    return eλ, dfthunk, d²fthunk, noargsback
end

# Primal definition: call `f` on `A` and any extra arguments.
function _apply_series_func(f, A, args...)
    return f(A, args...)
end

# Adjoint of `_apply_series_func`, computed through the eigendecomposition of `A`.
#
# Forward: `f` is applied to the eigenvalues and the result is rotated back with
# the eigenvectors, then wrapped by `_process_series_matrix`.
# Pullback: the cotangent is rotated into the eigenbasis, weighted elementwise by
# the pairwise difference-quotient matrix, and rotated back. `f` itself receives
# no gradient; extra arguments receive theirs from the diagonal of the rotated
# cotangent.
@adjoint function _apply_series_func(f, A, args...)
    hasargs = !isempty(args)
    n = LinearAlgebra.checksquare(A)
    λ, U = eigen(A)
    fλ, dfthunk, d²fthunk, argsback = _pullback_series_func_scalar(f, λ, args...)
    fΛ = Diagonal(fλ)
    fA = U * fΛ * U'
    Ω = _process_series_matrix(f, fA, A, fλ)
    return Ω, function (f̄A)
        f̄Λ = U' * f̄A * U
        ārgs = hasargs ? argsback(diag(f̄Λ)) : ()
        # The derivative thunks are only evaluated here, when the pullback runs.
        P = _pairdiffquotmat(f, n, λ, conj(fλ), dfthunk(), d²fthunk())
        Ā = U * (P .* f̄Λ) * U'
        return (nothing, Ā, ārgs...)
    end
end

# Integer power of a Symmetric/Hermitian matrix.
# `Symmetric` input goes through `LinearAlgebra.sympow`; `Hermitian` input uses `^`.
function _hermsympow(A::Symmetric, p::Integer)
    return LinearAlgebra.sympow(A, p)
end
function _hermsympow(A::Hermitian, p::Integer)
    return A^p
end

# Adjoint of `_hermsympow` for `Hermitian` input.
#
# Forward: differentiates `Base.power_by_squaring` (applied to `inv(A)` for
# negative powers), realifies the diagonal of the result and wraps it as
# `Hermitian`.
# Pullback: passes the cotangent through `_hermitian_back(Ω̄, 'U')` and then
# through the pullback of the power. The exponent `p` receives no gradient.
@adjoint function _hermsympow(A::Hermitian, p::Integer)
    if p < 0
        B, back = Zygote.pullback(A->Base.power_by_squaring(inv(A), -p), A)
    else
        B, back = Zygote.pullback(A->Base.power_by_squaring(A, p), A)
    end
    Ω = Hermitian(_realifydiag!(B))
    return Ω, function (Ω̄)
        B̄ = _hermitian_back(Ω̄, 'U')
        Ā = back(B̄)[1]
        return (Ā, nothing)
    end
end

# Route integer powers of Hermitian/Symmetric matrices through `_hermsympow`.
# One method per wrapper/element-type combination.
function _pullback(cx::AContext, ::typeof(^), A::LinearAlgebra.HermOrSym{<:Real},
                   p::Integer)
    return _pullback(cx, _hermsympow, A, p)
end
function _pullback(cx::AContext, ::typeof(^), A::Symmetric{<:Complex}, p::Integer)
    return _pullback(cx, _hermsympow, A, p)
end
function _pullback(cx::AContext, ::typeof(^), A::Hermitian{<:Complex}, p::Integer)
    return _pullback(cx, _hermsympow, A, p)
end

# Real (non-integer-typed) powers of real Hermitian/Symmetric or complex Hermitian
# matrices are routed through `_apply_series_func`.
_pullback(cx::AContext, f::typeof(^), A::LinearAlgebra.RealHermSymComplexHerm, p::Real) =
    _pullback(cx, (A, p) -> _apply_series_func(f, A, p), A, p)

# Generate one `_pullback` method per listed matrix function; each routes real
# Hermitian/Symmetric or complex Hermitian input through `_apply_series_func`.
for func in (:exp, :log, :cos, :sin, :tan, :cosh, :sinh, :tanh,
             :acos, :asin, :atan, :acosh, :asinh, :atanh, :sqrt)
    @eval function _pullback(cx::AContext, f::typeof($func),
                             A::LinearAlgebra.RealHermSymComplexHerm)
        return _pullback(cx, A -> _apply_series_func(f, A), A)
    end
end

# Adjoint of `sincos` for real Hermitian/Symmetric or complex Hermitian matrices.
#
# Forward: `sincos` is evaluated once per eigenvalue, and both matrices are
# rebuilt from the eigenvectors and wrapped by `_process_series_matrix`.
# Pullback: each cotangent is rotated into the eigenbasis and weighted by the
# matching pairwise difference-quotient matrix; the two contributions are summed
# and rotated back.
@adjoint function sincos(A::LinearAlgebra.RealHermSymComplexHerm)
    n = LinearAlgebra.checksquare(A)
    λ, U = eigen(A)
    sλ, cλ = Buffer(λ), Buffer(λ)
    for i in Base.OneTo(n)
        @inbounds sλ[i], cλ[i] = sincos(λ[i])
    end
    sinλ, cosλ = copy(sλ), copy(cλ)
    sinA, cosA = U * Diagonal(sinλ) * U', U * Diagonal(cosλ) * U'
    Ω, processback = Zygote.pullback(sinA, cosA) do s,c
        return (_process_series_matrix(sin, s, A, λ),
                _process_series_matrix(cos, c, A, λ))
    end
    return Ω, function (Ω̄)
        s̄inA, c̄osA = processback(Ω̄)
        s̄inΛ, c̄osΛ = U' * s̄inA * U, U' * c̄osA * U
        # sin: derivative cos, second derivative -sin.
        PS = _pairdiffquotmat(sin, n, λ, sinλ, cosλ, -sinλ)
        # cos: derivative -sin, second derivative -cos.
        PC = _pairdiffquotmat(cos, n, λ, cosλ, -sinλ, -cosλ)
        Ā = U * (PS .* s̄inΛ .+ PC .* c̄osΛ) * U'
        return (Ā,)
    end
end

Zygote.@adjoint function LinearAlgebra.tr(x::AbstractMatrix)
# x is a square matrix checked by tr,
# so we could just use Eye(size(x, 1))
Expand Down
3 changes: 3 additions & 0 deletions src/lib/number.jl
Original file line number Diff line number Diff line change
Expand Up @@ -49,6 +49,9 @@ end
(s, c), ((s̄, c̄),) -> (s̄*c -*s,)
end

# Adjoint of `acosh` for complex arguments: the cotangent is multiplied by the
# conjugate of `1 / (sqrt(x - 1) * sqrt(x + 1))`.
@adjoint acosh(x::Complex) =
    acosh(x), Δ -> (Δ * conj(inv(sqrt(x - 1) * sqrt(x + 1))),)

# Adjoint of rational division `a // b`: gradient `c̄ * 1//b` for the numerator
# and `-c̄ * a // b // b` for the denominator.
@adjoint a // b = (a // b, c̄ -> (c̄ * 1//b, -c̄ * a // b // b))

# NOTE(review): `@nograd` is defined outside this view; presumably it registers the
# listed functions as having no gradient. Confirm against the macro definition.
@nograd floor, ceil, trunc, round, hash
Expand Down
Loading

0 comments on commit 287f704

Please sign in to comment.