Skip to content

Commit

Permalink
Merge pull request #390 from JuliaDiff/mz/svd2
Browse files Browse the repository at this point in the history
Fix Composite of SVD
  • Loading branch information
mzgubic authored Mar 25, 2021
2 parents 76ef95c + 9e72757 commit 67c106f
Show file tree
Hide file tree
Showing 3 changed files with 24 additions and 44 deletions.
2 changes: 1 addition & 1 deletion Project.toml
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
name = "ChainRules"
uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2"
version = "0.7.54"
version = "0.7.55"

[deps]
ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4"
Expand Down
7 changes: 3 additions & 4 deletions src/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ function rrule(::typeof(svd), X::AbstractMatrix{<:Real})
F = svd(X)
function svd_pullback::Composite)
# `getproperty` on `Composite`s ensures we have no thunks.
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V)
∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt')
return (NO_FIELDS, ∂X)
end
return F, svd_pullback
Expand All @@ -221,10 +221,9 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD
elseif x === :S
C(S=Ȳ,)
elseif x === :V
C(V=Ȳ,)
C(Vt=Ȳ',)
elseif x === :Vt
# TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106
throw(ArgumentError("Vt is unsupported; use V and transpose the result"))
C(Vt=Ȳ,)
end
return NO_FIELDS, ∂F, DoesNotExist()
end
Expand Down
59 changes: 20 additions & 39 deletions test/rulesets/LinearAlgebra/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -79,41 +79,39 @@ end
end
end
@testset "svd" begin
for n in [4, 6, 10], m in [3, 5, 10]
X = randn(n, m)
F, dX_pullback = rrule(svd, X)
for p in [:U, :S, :V]
Y, dF_pullback = rrule(getproperty, F, p)
= randn(size(Y)...)

dself1, dF, dp = dF_pullback(Ȳ)
@test dself1 === NO_FIELDS
@test dp === DoesNotExist()

dself2, dX = dX_pullback(dF)
@test dself2 === NO_FIELDS
X̄_ad = unthunk(dX)
X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X))
@test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6))
for n in [4, 6, 10], m in [3, 5, 9]
@testset "($n x $m) svd" begin
X = randn(n, m)
@show X
test_rrule(svd, X; atol=1e-6, rtol=1e-6)
end
@testset "Vt" begin
Y, dF_pullback = rrule(getproperty, F, :Vt)
= randn(size(Y)...)
@test_throws ArgumentError dF_pullback(Ȳ)
end

for n in [4, 6, 10], m in [3, 5, 10]
@testset "($n x $m) getproperty" begin
X = randn(n, m)
F = svd(X)
rand_adj = adjoint(rand(reverse(size(F.V))...))

test_rrule(getproperty, F, :U nothing; check_inferred=false)
test_rrule(getproperty, F, :S nothing; check_inferred=false)
test_rrule(getproperty, F, :Vt nothing; check_inferred=false)
test_rrule(getproperty, F, :V nothing; check_inferred=false, output_tangent=rand_adj)
end
end

@testset "Thunked inputs" begin
X = randn(4, 3)
F, dX_pullback = rrule(svd, X)
for p in [:U, :S, :V]
for p in [:U, :S, :V, :Vt]
Y, dF_pullback = rrule(getproperty, F, p)
= randn(size(Y)...)

_, dF_unthunked, _ = dF_pullback(Ȳ)

# helper to let us check how things are stored.
backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p)
p_access = p == :V ? :Vt : p
backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p_access)
@assert !(backing_field(dF_unthunked, p) isa AbstractThunk)

dF_thunked = map(f->Thunk(()->f), dF_unthunked)
Expand All @@ -126,23 +124,6 @@ end
end
end

@testset "+" begin
X = [1.0 2.0; 3.0 4.0; 5.0 6.0]
F, dX_pullback = rrule(svd, X)
= Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2))
for p in [:U, :S, :V]
Y, dF_pullback = rrule(getproperty, F, p)
= ones(size(Y)...)
dself, dF, dp = dF_pullback(Ȳ)
@test dself === NO_FIELDS
@test dp === DoesNotExist()
+= dF
end
@test.U ones(3, 2) atol=1e-6
@test.S ones(2) atol=1e-6
@test.V ones(2, 2) atol=1e-6
end

@testset "Helper functions" begin
X = randn(10, 10)
Y = randn(10, 10)
Expand Down

0 comments on commit 67c106f

Please sign in to comment.