Skip to content

Commit

Permalink
Merge branch 'master' into eigen
Browse files Browse the repository at this point in the history
  • Loading branch information
sethaxen authored Dec 9, 2020
2 parents 54a8e95 + 80443b1 commit 6a519a4
Show file tree
Hide file tree
Showing 5 changed files with 126 additions and 231 deletions.
6 changes: 5 additions & 1 deletion src/ChainRules.jl
Original file line number Diff line number Diff line change
Expand Up @@ -58,8 +58,12 @@ function __init__()
include("rulesets/packages/NaNMath.jl")
end

# Note: drop SpecialFunctions dependency in next breaking release
# https://github.com/JuliaDiff/ChainRules.jl/issues/319
@require SpecialFunctions="276daf66-3868-5448-9aa4-cd146d93841b" begin
include("rulesets/packages/SpecialFunctions.jl")
if !isdefined(SpecialFunctions, :ChainRulesCore)
include("rulesets/packages/SpecialFunctions.jl")
end
end
end

Expand Down
229 changes: 60 additions & 169 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -202,20 +202,69 @@ end
##### `cholesky`
#####

function rrule(::typeof(cholesky), X::AbstractMatrix{<:Real})
F = cholesky(X)
function cholesky_pullback::Composite)
∂X = if F.uplo === 'U'
chol_blocked_rev.U, F.U, 25, true)
else
chol_blocked_rev.L, F.L, 25, false)
end
return (NO_FIELDS, ∂X)
# Reverse rule for `cholesky` applied to a real scalar `A`.
#
# Returns the factorisation and a pullback. The pullback maps the cotangent `ΔC` of the
# factorisation to `(NO_FIELDS, Ā, DoesNotExist())`, where `Ā` is the `[1, 1]` entry of
# `ΔC.factors` divided by twice the `[1, 1]` entry of the factor `U`.
function rrule(::typeof(cholesky), A::Real, uplo::Symbol=:U)
    factorization = cholesky(A, uplo)
    function cholesky_pullback(ΔC::Composite)
        twice_u = 2 * factorization.U[1, 1]
        Ā = ΔC.factors[1, 1] / twice_u
        # `uplo` is a Symbol, so it has no cotangent.
        return (NO_FIELDS, Ā, DoesNotExist())
    end
    return (factorization, cholesky_pullback)
end

# Reverse rule for the unpivoted (`Val(false)`) Cholesky factorisation of a real
# `Diagonal` matrix.
#
# Returns the factorisation `C` and a pullback. The pullback maps the cotangent `ΔC` of
# the factorisation to `(NO_FIELDS, Ā, DoesNotExist())`, where `Ā` is a `Diagonal` whose
# entries are `diag(ΔC.factors) ./ (2 .* C.factors.diag)`.
function rrule(::typeof(cholesky), A::Diagonal{<:Real}, ::Val{false}; check::Bool=true)
    C = cholesky(A, Val(false); check=check)
    function cholesky_pullback(ΔC::Composite)
        # Elementwise rule: only the diagonal of the factor cotangent contributes.
        Ā = Diagonal(diag(ΔC.factors) .* inv.(2 .* C.factors.diag))
        # The `Val(false)` argument carries no cotangent.
        return NO_FIELDS, Ā, DoesNotExist()
    end
    return C, cholesky_pullback
end

# The appropriate cotangent is different depending upon whether A is Symmetric / Hermitian,
# or just a StridedMatrix.
# Implementation due to Seeger, Matthias, et al. "Auto-differentiating linear algebra."
# Reverse rule for the unpivoted (`Val(false)`) Cholesky factorisation of a
# `Symmetric`/`Hermitian` wrapper around a strided matrix with a BLAS real element type.
#
# Returns the factorisation `C` and a pullback. The pullback maps the cotangent `ΔC` of
# the factorisation to `(NO_FIELDS, Ā, DoesNotExist())`, where `Ā` is wrapped in the same
# symmetric/Hermitian type as `A` (via `_symhermtype`).
function rrule(
    ::typeof(cholesky),
    A::LinearAlgebra.HermOrSym{<:LinearAlgebra.BlasReal, <:StridedMatrix},
    ::Val{false};
    check::Bool=true,
)
    C = cholesky(A, Val(false); check=check)
    function cholesky_pullback(ΔC::Composite)
        Ā, U = _cholesky_pullback_shared_code(C, ΔC)
        # In-place right solve against the upper-triangular factor:
        # Ā ← (1/2) * Ā * inv(U'). The factor 1/2 is folded into the BLAS scaling.
        Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)) / 2, U.data, Ā)
        return NO_FIELDS, _symhermtype(A)(Ā), DoesNotExist()
    end
    return C, cholesky_pullback
end

# Reverse rule for the unpivoted (`Val(false)`) Cholesky factorisation of a plain strided
# matrix with a BLAS real element type.
#
# Returns the factorisation `C` and a pullback. The pullback maps the cotangent `ΔC` of
# the factorisation to `(NO_FIELDS, UpperTriangular(Ā), DoesNotExist())`.
function rrule(
    ::typeof(cholesky),
    A::StridedMatrix{<:LinearAlgebra.BlasReal},
    ::Val{false};
    check::Bool=true,
)
    C = cholesky(A, Val(false); check=check)
    function cholesky_pullback(ΔC::Composite)
        Ā, U = _cholesky_pullback_shared_code(C, ΔC)
        # In-place right solve against the upper-triangular factor: Ā ← Ā * inv(U').
        Ā = BLAS.trsm!('R', 'U', 'C', 'N', one(eltype(Ā)), U.data, Ā)
        # Halve the diagonal; the off-diagonal entries keep full weight.
        idx = diagind(Ā)
        @views Ā[idx] .= real.(Ā[idx]) ./ 2
        return (NO_FIELDS, UpperTriangular(Ā), DoesNotExist())
    end
    return C, cholesky_pullback
end

# Shared first stage of the `cholesky` pullbacks.
#
# Takes the factorisation `C` and its cotangent `ΔC` (anything exposing a `U` property)
# and returns `(Ā, U)`, where `U = C.U` and `Ā` is a freshly allocated dense matrix
# holding `U \ Φ`, with `Φ` the upper triangle of `ΔC.U * U'` mirrored onto the lower
# triangle. Callers finish the computation with a further triangular solve.
function _cholesky_pullback_shared_code(C, ΔC)
    U = C.U
    Ū = ΔC.U
    # Fresh buffer so that neither `C` nor `ΔC` is mutated.
    Ā = similar(U.data)
    Ā = mul!(Ā, Ū, U')
    # Copy the (conjugated) upper triangle onto the lower triangle.
    Ā = LinearAlgebra.copytri!(Ā, 'U', true)
    Ā = ldiv!(U, Ā)
    return Ā, U
end

function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
function rrule(::typeof(getproperty), F::T, x::Symbol) where {T <: Cholesky}
function getproperty_cholesky_pullback(Ȳ)
C = Composite{T}
∂F = if x === :U
Expand All @@ -235,161 +284,3 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: Cholesky
end
return getproperty(F, x), getproperty_cholesky_pullback
end

# See "Differentiation of the Cholesky decomposition" (Murray 2016), pages 5-9 in particular,
# for derivations. Here we're implementing the algorithms and their transposes.

"""
level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
Returns views to various bits of the lower triangle of `A` according to the
`level2partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
the transposed views are returned from the upper triangle of `A`.
[1]: "Differentiation of the Cholesky decomposition", Murray 2016
"""
function level2partition(A::AbstractMatrix, j::Integer, upper::Bool)
    # `A` must be square and `j` must index one of its rows/columns.
    n = checksquare(A)
    @boundscheck checkbounds(1:n, j)

    # Index ranges strictly before and strictly after the pivot position `j`.
    before = 1:j-1
    after = j+1:n

    # The pivot element is shared by both layouts.
    d = view(A, j, j)
    if upper
        # Transposed layout, read from the upper triangle.
        return view(A, before, j), d, view(A, before, after), view(A, j, after)
    end
    # Default layout, read from the lower triangle.
    return view(A, j, before), d, view(A, after, before), view(A, after, j)
end

"""
level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
Returns views to various bits of the lower triangle of `A` according to the
`level3partition` procedure defined in [1] if `upper` is `false`. If `upper` is `true` then
the transposed views are returned from the upper triangle of `A`.
[1]: "Differentiation of the Cholesky decomposition", Murray 2016
"""
function level3partition(A::AbstractMatrix, j::Integer, k::Integer, upper::Bool)
    # `A` must be square; only `j` is range-checked here.
    n = checksquare(A)
    @boundscheck checkbounds(1:n, j)

    # Index ranges before, inside and after the diagonal block `j:k`.
    before = 1:j-1
    block = j:k
    after = k+1:n

    # The diagonal block is shared by both layouts.
    D = view(A, block, block)
    if upper
        # Transposed layout, read from the upper triangle.
        return view(A, before, block), D, view(A, before, after), view(A, block, after)
    end
    # Default layout, read from the lower triangle.
    return view(A, block, before), D, view(A, after, before), view(A, after, block)
end

"""
chol_unblocked_rev!(Ā::AbstractMatrix, L::AbstractMatrix, upper::Bool)
Compute the reverse-mode sensitivities of the Cholesky factorization in an unblocked manner.
If `upper` is `false`, then the sensitivities are computed from and stored in the lower triangle
of `Ā` and `L` respectively. If `upper` is `true` then they are computed and stored in the
upper triangles. If at input `upper` is `false` and `tril(Ā) = L̄`, at output
`tril(Ā) = tril(Σ̄)`, where `Σ = LLᵀ`. Analogously, if at input `upper` is `true` and
`triu(Ā) = triu(Ū)`, at output `triu(Ā) = triu(Σ̄)` where `Σ = UᵀU`.
"""
function chol_unblocked_rev!(
    Σ̄::AbstractMatrix{T}, L::AbstractMatrix{T}, upper::Bool
) where T<:Real
    # Overwrites `Σ̄` in place and returns it with the unused triangle zeroed.
    # `L` is only read. Columns (rows, when `upper`) are processed from last to first.
    n = checksquare(Σ̄)
    j = n
    @inbounds for _ in 1:n
        # Matching views into the factor and into the sensitivities.
        r, d, B, c = level2partition(L, j, upper)
        r̄, d̄, B̄, c̄ = level2partition(Σ̄, j, upper)

        # d̄ <- d̄ - c'c̄ / d.
        d̄[1] -= dot(c, c̄) / d[1]

        # [d̄ c̄'] <- [d̄ c̄'] / d.
        d̄ ./= d
        c̄ ./= d

        # r̄ <- r̄ - [d̄ c̄'] [r' B']'.
        r̄ = axpy!(-Σ̄[j, j], r, r̄)
        r̄ = gemv!(upper ? 'n' : 'T', -one(T), B, c̄, one(T), r̄)

        # B̄ <- B̄ - c̄ r.
        B̄ = upper ? ger!(-one(T), r, c̄, B̄) : ger!(-one(T), c̄, r, B̄)

        # Halve the diagonal sensitivity once all other updates have used it.
        d̄ ./= 2
        j -= 1
    end
    return (upper ? triu! : tril!)(Σ̄)
end

# Non-mutating counterpart of `chol_unblocked_rev!`: the work is done on a copy of `Σ̄`,
# so the caller's matrix is left untouched.
function chol_unblocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, upper::Bool)
    scratch = copy(Σ̄)
    return chol_unblocked_rev!(scratch, L, upper)
end

"""
chol_blocked_rev!(Σ̄::StridedMatrix, L::StridedMatrix, nb::Integer, upper::Bool)
Compute the sensitivities of the Cholesky factorization using a blocked, cache-friendly
procedure. `Σ̄` are the sensitivities of `L`, and will be transformed into the sensitivities
of `Σ`, where `Σ = LLᵀ`. `nb` is the block size to use. If the upper triangle has been used
to represent the factorization, that is `Σ = UᵀU` where `U := Lᵀ`, then this should be
indicated by passing `upper = true`.
"""
function chol_blocked_rev!(
    Σ̄::StridedMatrix{T}, L::StridedMatrix{T}, nb::Integer, upper::Bool
) where T<:Real
    # Overwrites `Σ̄` in place and returns it with the unused triangle zeroed.
    # `L` is only read. Diagonal blocks of at most `nb` rows/columns are processed
    # from the bottom-right corner towards the top-left.
    n = checksquare(Σ̄)
    # Scratch space for `D̄ + D̄'`, reused whenever the diagonal block is exactly nb×nb.
    tmp = Matrix{T}(undef, nb, nb)
    k = n
    if upper
        @inbounds for _ in 1:nb:n
            j = max(1, k - nb + 1)
            R, D, B, C = level3partition(L, j, k, true)
            R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, true)

            # C̄ <- D \ C̄ (D upper triangular).
            C̄ = trsm!('L', 'U', 'N', 'N', one(T), D, C̄)
            # B̄ <- B̄ - R C̄.
            gemm!('N', 'N', -one(T), R, C̄, one(T), B̄)
            # D̄ <- D̄ - C C̄'.
            gemm!('N', 'T', -one(T), C, C̄, one(T), D̄)
            chol_unblocked_rev!(D̄, D, true)
            # R̄ <- R̄ - B C̄'.
            gemm!('N', 'T', -one(T), B, C̄, one(T), R̄)
            # R̄ <- R̄ - R (D̄ + D̄').
            if size(D̄, 1) == nb
                tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
                gemm!('N', 'N', -one(T), R, tmp, one(T), R̄)
            else
                # Short leading block: `tmp` has the wrong size, so allocate the sum.
                gemm!('N', 'N', -one(T), R, D̄ + D̄', one(T), R̄)
            end

            k -= nb
        end
        return triu!(Σ̄)
    else
        @inbounds for _ in 1:nb:n
            j = max(1, k - nb + 1)
            R, D, B, C = level3partition(L, j, k, false)
            R̄, D̄, B̄, C̄ = level3partition(Σ̄, j, k, false)

            # C̄ <- C̄ / D (D lower triangular).
            C̄ = trsm!('R', 'L', 'N', 'N', one(T), D, C̄)
            # B̄ <- B̄ - C̄ R.
            gemm!('N', 'N', -one(T), C̄, R, one(T), B̄)
            # D̄ <- D̄ - C̄' C.
            gemm!('T', 'N', -one(T), C̄, C, one(T), D̄)
            chol_unblocked_rev!(D̄, D, false)
            # R̄ <- R̄ - C̄' B.
            gemm!('T', 'N', -one(T), C̄, B, one(T), R̄)
            # R̄ <- R̄ - (D̄ + D̄') R.
            if size(D̄, 1) == nb
                tmp = axpy!(one(T), D̄, transpose!(tmp, D̄))
                gemm!('N', 'N', -one(T), tmp, R, one(T), R̄)
            else
                # Short leading block: `tmp` has the wrong size, so allocate the sum.
                gemm!('N', 'N', -one(T), D̄ + D̄', R, one(T), R̄)
            end

            k -= nb
        end
        return tril!(Σ̄)
    end
end

# Non-mutating counterpart of `chol_blocked_rev!` that accepts any `AbstractMatrix`.
function chol_blocked_rev(Σ̄::AbstractMatrix, L::AbstractMatrix, nb::Integer, upper::Bool)
    # The BLAS routines used downstream need StridedMatrix arguments, so both inputs are
    # materialised as dense `Matrix` copies first.
    dense_sensitivities = Matrix(Σ̄)
    dense_factor = Matrix(L)
    return chol_blocked_rev!(dense_sensitivities, dense_factor, nb, upper)
end
113 changes: 55 additions & 58 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
@@ -1,4 +1,15 @@
using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblocked_rev
# Test helper: flatten a `Cholesky` into a vector via its `factors` field, and return a
# closure that rebuilds a `Cholesky` with the same `uplo` and `info` from such a vector.
function FiniteDifferences.to_vec(C::Cholesky)
    flat_factors, rebuild_factors = to_vec(C.factors)
    cholesky_from_vec(v) = Cholesky(rebuild_factors(v), C.uplo, C.info)
    return (flat_factors, cholesky_from_vec)
end

# Test helper: a `Val` carries no numeric data, so it flattens to an empty vector and the
# reconstruction closure ignores its argument and hands back the original `Val`.
function FiniteDifferences.to_vec(x::Val)
    function Val_from_vec(v)
        return x
    end
    return (Bool[], Val_from_vec)
end

@testset "Factorizations" begin
@testset "svd" begin
Expand Down Expand Up @@ -73,6 +84,7 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
@test ChainRules._eyesubx!(copy(X)) I - X
end
end

@testset "eigendecomposition" begin
@testset "eigen/eigen!" begin
# NOTE: eigen!/eigen are not type-stable, so neither are their frule/rrule
Expand Down Expand Up @@ -204,69 +216,54 @@ using ChainRules: level2partition, level3partition, chol_blocked_rev, chol_unblo
end
end
end
@testset "cholesky" begin
@testset "the thing" begin
X = generate_well_conditioned_matrix(10)
V = generate_well_conditioned_matrix(10)
F, dX_pullback = rrule(cholesky, X)
for p in [:U, :L]
Y, dF_pullback = rrule(getproperty, F, p)
= (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y)))
(dself, dF, dp) = dF_pullback(Ȳ)
@test dself === NO_FIELDS
@test dp === DoesNotExist()

# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
# machinery from FiniteDifferences because that isn't set up to respect
# necessary special properties of the input. In the case of the Cholesky
# factorization, we need the input to be Hermitian.
ΔF = unthunk(dF)
_, dX = dX_pullback(ΔF)
X̄_ad = dot(unthunk(dX), V)
X̄_fd = _fdm(0.0) do ε
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
end
@test X̄_ad X̄_fd rtol=1e-6 atol=1e-6
end
# These tests are generally a bit tricky to write because FiniteDifferences doesn't
# have fantastic support for this stuff at the minute.
@testset "cholesky" begin
@testset "Real" begin
C = cholesky(rand() + 0.1)
ΔC = Composite{typeof(C)}((factors=rand_tangent(C.factors)))
rrule_test(cholesky, ΔC, (rand() + 0.1, randn()))
end
@testset "Diagonal{<:Real}" begin
D = Diagonal(rand(5) .+ 0.1)
C = cholesky(D)
ΔC = Composite{typeof(C)}((factors=Diagonal(randn(5))))
rrule_test(cholesky, ΔC, (D, Diagonal(randn(5))), (Val(false), nothing))
end
@testset "helper functions" begin
A = randn(5, 5)
r, d, B2, c = level2partition(A, 4, false)
R, D, B3, C = level3partition(A, 4, 4, false)
@test all(r .== R')
@test all(d .== D)
@test B2[1] == B3[1]
@test all(c .== C)

# Check that level 2 partition with `upper == true` is consistent with `false`
rᵀ, dᵀ, B2ᵀ, cᵀ = level2partition(transpose(A), 4, true)
@test r == rᵀ
@test d == dᵀ
@test B2' == B2ᵀ
@test c == cᵀ
X = generate_well_conditioned_matrix(10)
V = generate_well_conditioned_matrix(10)
F, dX_pullback = rrule(cholesky, X, Val(false))
@testset "uplo=$p" for p in [:U, :L]
Y, dF_pullback = rrule(getproperty, F, p)
= (p === :U ? UpperTriangular : LowerTriangular)(randn(size(Y)))
(dself, dF, dp) = dF_pullback(Ȳ)
@test dself === NO_FIELDS
@test dp === DoesNotExist()

# Check that level 3 partition with `upper == true` is consistent with `false`
R, D, B3, C = level3partition(A, 2, 4, false)
Rᵀ, Dᵀ, B3ᵀ, Cᵀ = level3partition(transpose(A), 2, 4, true)
@test transpose(R) == Rᵀ
@test transpose(D) == Dᵀ
@test transpose(B3) == B3ᵀ
@test transpose(C) == Cᵀ
# NOTE: We're doing Nabla-style testing here and avoiding using the `j′vp`
# machinery from FiniteDifferences because that isn't set up to respect
# necessary special properties of the input. In the case of the Cholesky
# factorization, we need the input to be Hermitian.
ΔF = unthunk(dF)
_, dX = dX_pullback(ΔF)
X̄_ad = dot(unthunk(dX), V)
X̄_fd = central_fdm(5, 1)(0.000_001) do ε
dot(Ȳ, getproperty(cholesky(X .+ ε .* V), p))
end
@test X̄_ad X̄_fd rtol=1e-4
end

A = Matrix(LowerTriangular(randn(10, 10)))
= Matrix(LowerTriangular(randn(10, 10)))
# NOTE: BLAS gets angry if we don't materialize the Transpose objects first
B = Matrix(transpose(A))
= Matrix(transpose(Ā))
@test chol_unblocked_rev(Ā, A, false) chol_blocked_rev(Ā, A, 1, false)
@test chol_unblocked_rev(Ā, A, false) chol_blocked_rev(Ā, A, 3, false)
@test chol_unblocked_rev(Ā, A, false) chol_blocked_rev(Ā, A, 5, false)
@test chol_unblocked_rev(Ā, A, false) chol_blocked_rev(Ā, A, 10, false)
@test chol_unblocked_rev(Ā, A, false) transpose(chol_unblocked_rev(B̄, B, true))
# Ensure that cotangents of cholesky(::StridedMatrix) and
# (cholesky ∘ Symmetric)(::StridedMatrix) are equal.
@testset "Symmetric" begin
X_symmetric, sym_back = rrule(Symmetric, X, :U)
C, chol_back_sym = rrule(cholesky, X_symmetric, Val(false))

@test chol_unblocked_rev(B̄, B, true) chol_blocked_rev(B̄, B, 1, true)
@test chol_unblocked_rev(B̄, B, true) chol_blocked_rev(B̄, B, 5, true)
@test chol_unblocked_rev(B̄, B, true) chol_blocked_rev(B̄, B, 10, true)
Δ = Composite{typeof(C)}((U=UpperTriangular(randn(size(X)))))
ΔX_symmetric = chol_back_sym(Δ)[2]
@test sym_back(ΔX_symmetric)[2] dX_pullback(Δ)[2]
end
end
end
2 changes: 0 additions & 2 deletions test/rulesets/packages/SpecialFunctions.jl
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
using SpecialFunctions

@testset "SpecialFunctions" for x in (1.0, -1.0, 0.0, 0.5, 10.0, -17.1, 1.5 + 0.7im)
test_scalar(SpecialFunctions.erf, x)
test_scalar(SpecialFunctions.erfc, x)
Expand Down
Loading

0 comments on commit 6a519a4

Please sign in to comment.