From 0951ffacedf350c3f15ad87e9fe91049855b1a1b Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 23 Mar 2021 20:53:47 +0000 Subject: [PATCH 01/10] adapt tests to new powers --- test/rulesets/LinearAlgebra/factorization.jl | 16 ++++++---------- 1 file changed, 6 insertions(+), 10 deletions(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index c70d0684b..b75731071 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -82,7 +82,7 @@ end for n in [4, 6, 10], m in [3, 5, 10] X = randn(n, m) F, dX_pullback = rrule(svd, X) - for p in [:U, :S, :V] + for p in [:U, :S, :V, :Vt] Y, dF_pullback = rrule(getproperty, F, p) Ȳ = randn(size(Y)...) @@ -96,24 +96,20 @@ end X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)) @test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6)) end - @testset "Vt" begin - Y, dF_pullback = rrule(getproperty, F, :Vt) - Ȳ = randn(size(Y)...) - @test_throws ArgumentError dF_pullback(Ȳ) - end end @testset "Thunked inputs" begin X = randn(4, 3) F, dX_pullback = rrule(svd, X) - for p in [:U, :S, :V] + for p in [:U, :S, :V, :Vt] Y, dF_pullback = rrule(getproperty, F, p) Ȳ = randn(size(Y)...) _, dF_unthunked, _ = dF_pullback(Ȳ) # helper to let us check how things are stored. - backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p) + p_access = p == :V ? :Vt : p + backing_field(c, p) = getproperty(ChainRulesCore.backing(c), p_access) @assert !(backing_field(dF_unthunked, p) isa AbstractThunk) dF_thunked = map(f->Thunk(()->f), dF_unthunked) @@ -130,7 +126,7 @@ end X = [1.0 2.0; 3.0 4.0; 5.0 6.0] F, dX_pullback = rrule(svd, X) X̄ = Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2)) - for p in [:U, :S, :V] + for p in [:U, :S, :V, :Vt] Y, dF_pullback = rrule(getproperty, F, p) Ȳ = ones(size(Y)...) dself, dF, dp = dF_pullback(Ȳ) @@ -140,7 +136,7 @@ end end @test X̄.U ≈ ones(3, 2) atol=1e-6 @test X̄.S ≈ ones(2) atol=1e-6 - @test X̄.V ≈ ones(2, 2) atol=1e-6 + @test X̄.Vt ≈ 2 * ones(2, 2) atol=1e-6 # * 2 because V and Vt both accumulate to Vt end @testset "Helper functions" begin From 082019c57659fa13672d1f416d70d42516ee16a4 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 23 Mar 2021 20:54:06 +0000 Subject: [PATCH 02/10] add new implementation --- src/rulesets/LinearAlgebra/factorization.jl | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 2ec9e2c97..5f00320a6 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -207,7 +207,7 @@ function rrule(::typeof(svd), X::AbstractMatrix{<:Real}) F = svd(X) function svd_pullback(Ȳ::Composite) # `getproperty` on `Composite`s ensures we have no thunks. - ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.V) + ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt') return (NO_FIELDS, ∂X) end return F, svd_pullback @@ -221,10 +221,9 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD elseif x === :S C(S=Ȳ,) elseif x === :V - C(V=Ȳ,) + C(Vt=Ȳ',) elseif x === :Vt - # TODO: https://github.com/JuliaDiff/ChainRules.jl/issues/106 - throw(ArgumentError("Vt is unsupported; use V and transpose the result")) + C(Vt=Ȳ,) end return NO_FIELDS, ∂F, DoesNotExist() end From 0ab70af591aa2f6bd43a8d2dc90e10ea43eb98da Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 23 Mar 2021 20:54:17 +0000 Subject: [PATCH 03/10] bump patch --- Project.toml | 1 + 1 file changed, 1 insertion(+) diff --git a/Project.toml b/Project.toml index 865448158..6d9e4db24 100644 --- a/Project.toml +++ b/Project.toml @@ -5,6 +5,7 @@ version = "0.7.54" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" From 5ba3dbdd6d8dede33f12ab855937c61ab06c8d9d Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Tue, 23 Mar 2021 21:02:49 +0000 Subject: [PATCH 04/10] bump patch --- Project.toml | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/Project.toml b/Project.toml index 6d9e4db24..0144bcec2 100644 --- a/Project.toml +++ b/Project.toml @@ -1,11 +1,10 @@ name = "ChainRules" uuid = "082447d4-558c-5d27-93f4-14fc19e9eca2" -version = "0.7.54" +version = "0.7.55" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" From 5d880fe37bc6efa65684019022c5533b14718d23 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 24 Mar 2021 17:26:00 +0000 Subject: [PATCH 05/10] replace tests with test_rrule calls --- Project.toml | 2 + test/rulesets/LinearAlgebra/factorization.jl | 40 +++++--------------- 2 files changed, 11 insertions(+), 31 deletions(-) diff --git a/Project.toml b/Project.toml index 0144bcec2..80d00d01c 100644 --- a/Project.toml +++ b/Project.toml @@ -4,7 +4,9 @@ version = "0.7.55" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" +ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" +FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index b75731071..164acd892 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -81,20 +81,15 @@ end @testset "svd" begin for n in [4, 6, 10], m in [3, 5, 10] X = randn(n, m) - F, dX_pullback = rrule(svd, X) - for p in [:U, :S, :V, :Vt] - Y, dF_pullback = rrule(getproperty, F, p) - Ȳ = randn(size(Y)...) - - dself1, dF, dp = dF_pullback(Ȳ) - @test dself1 === NO_FIELDS - @test dp === DoesNotExist() - - dself2, dX = dX_pullback(dF) - @test dself2 === NO_FIELDS - X̄_ad = unthunk(dX) - X̄_fd = only(j′vp(central_fdm(5, 1), X->getproperty(svd(X), p), Ȳ, X)) - @test all(isapprox.(X̄_ad, X̄_fd; rtol=1e-6, atol=1e-6)) + @testset "($n by $m) svd" begin + test_rrule(svd, X) + end + @testset "($n by $m) getproperty" begin + F = svd(X) + test_rrule(getproperty, F, :U; check_inferred=false) + test_rrule(getproperty, F, :S; check_inferred=false) + test_rrule(getproperty, F, :Vt; check_inferred=false) + test_rrule(getproperty, F, :V; check_inferred=false, output_tangent=adjoint(rand(n, m))) end end @@ -122,23 +117,6 @@ end end end - @testset "+" begin - X = [1.0 2.0; 3.0 4.0; 5.0 6.0] - F, dX_pullback = rrule(svd, X) - X̄ = Composite{typeof(F)}(U=zeros(3, 2), S=zeros(2), V=zeros(2, 2)) - for p in [:U, :S, :V, :Vt] - Y, dF_pullback = rrule(getproperty, F, p) - Ȳ = ones(size(Y)...) - dself, dF, dp = dF_pullback(Ȳ) - @test dself === NO_FIELDS - @test dp === DoesNotExist() - X̄ += dF - end - @test X̄.U ≈ ones(3, 2) atol=1e-6 - @test X̄.S ≈ ones(2) atol=1e-6 - @test X̄.Vt ≈ 2 * ones(2, 2) atol=1e-6 # * 2 because V and Vt both accumulate to Vt - end - @testset "Helper functions" begin X = randn(10, 10) Y = randn(10, 10) From ac67ded18a83d7709d081f08aaaf295bff645dcf Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 24 Mar 2021 17:50:47 +0000 Subject: [PATCH 06/10] organise tests --- test/rulesets/LinearAlgebra/factorization.jl | 20 +++++++++++++------- 1 file changed, 13 insertions(+), 7 deletions(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 164acd892..c4dc3bb9b 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -80,16 +80,22 @@ end end @testset "svd" begin for n in [4, 6, 10], m in [3, 5, 10] - X = randn(n, m) - @testset "($n by $m) svd" begin + @testset "svd" begin + X = randn(n, m) test_rrule(svd, X) end - @testset "($n by $m) getproperty" begin + end + + for n in [4, 6, 10], m in [3, 5, 10] + @testset "getproperty" begin + X = randn(n, m) F = svd(X) - test_rrule(getproperty, F, :U; check_inferred=false) - test_rrule(getproperty, F, :S; check_inferred=false) - test_rrule(getproperty, F, :Vt; check_inferred=false) - test_rrule(getproperty, F, :V; check_inferred=false, output_tangent=adjoint(rand(n, m))) + rand_adj = adjoint(rand(reverse(size(F.V))...)) + + test_rrule(getproperty, F, :U ⊢ nothing; check_inferred=false) + test_rrule(getproperty, F, :S ⊢ nothing; check_inferred=false) + test_rrule(getproperty, F, :Vt ⊢ nothing; check_inferred=false) + test_rrule(getproperty, F, :V ⊢ nothing; check_inferred=false, output_tangent=rand_adj) end end From e60a6f08f82f01faabd6af4b7fa790240f9fb42d Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 24 Mar 2021 17:51:22 +0000 Subject: [PATCH 07/10] remove deps --- Project.toml | 2 -- 1 file changed, 2 deletions(-) diff --git a/Project.toml b/Project.toml index 80d00d01c..0144bcec2 100644 --- a/Project.toml +++ b/Project.toml @@ -4,9 +4,7 @@ version = "0.7.55" [deps] ChainRulesCore = "d360d2e6-b24c-11e9-a2a3-2a2ae2dbcce4" -ChainRulesTestUtils = "cdddcdb0-9152-4a09-a978-84456f9df70a" Compat = "34da2185-b29b-5c13-b0c7-acf172513d20" -FiniteDifferences = "26cc04aa-876d-5657-8c51-4c34ba976000" LinearAlgebra = "37e2e46d-f89d-539d-b4ee-838fcccc9c8e" Random = "9a3f8284-a2c9-5f02-9a11-845980a1fd5c" Reexport = "189a3867-3050-52da-a836-e630ba90ab69" From 285977beee3eae56c63d51c4c1e136b1c535c629 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 24 Mar 2021 18:32:24 +0000 Subject: [PATCH 08/10] i guess there was a reason these tolerances were set --- test/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index c4dc3bb9b..4b8a150e6 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -82,7 +82,7 @@ end for n in [4, 6, 10], m in [3, 5, 10] @testset "svd" begin X = randn(n, m) - test_rrule(svd, X) + test_rrule(svd, X; atol=1e-6, rtol=1e-6) end end From 81b2e08d59c731893de230d584882b1c143b3fa9 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 24 Mar 2021 20:22:07 +0000 Subject: [PATCH 09/10] avoid numerical problems in fd --- test/rulesets/LinearAlgebra/factorization.jl | 9 +++++---- 1 file changed, 5 insertions(+), 4 deletions(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 4b8a150e6..2f39c3b75 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -79,15 +79,16 @@ end end end @testset "svd" begin - for n in [4, 6, 10], m in [3, 5, 10] - @testset "svd" begin + for n in [4, 6, 10], m in [3, 5, 9] + @testset "($n x $m) svd" begin X = randn(n, m) - test_rrule(svd, X; atol=1e-6, rtol=1e-6) + @show X + test_rrule(svd, X) end end for n in [4, 6, 10], m in [3, 5, 10] - @testset "getproperty" begin + @testset "($n x $m) getproperty" begin X = randn(n, m) F = svd(X) rand_adj = adjoint(rand(reverse(size(F.V))...)) From 9e727574901b6f408991349a04929fd550bc2157 Mon Sep 17 00:00:00 2001 From: Miha Zgubic Date: Wed, 24 Mar 2021 22:25:50 +0000 Subject: [PATCH 10/10] relax tolerance --- test/rulesets/LinearAlgebra/factorization.jl | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 2f39c3b75..71294c812 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -83,7 +83,7 @@ end @testset "($n x $m) svd" begin X = randn(n, m) @show X - test_rrule(svd, X) + test_rrule(svd, X; atol=1e-6, rtol=1e-6) end end