From ba899016b7b83eeaf89c35325636a550fb69241f Mon Sep 17 00:00:00 2001 From: Per Rutquist Date: Fri, 10 Nov 2023 12:25:09 +0100 Subject: [PATCH 1/5] More efficient projection in svd pullback The formula (I - U*U')*X can be extremely slow and memory-intensive when U and X are very tall matrices. Replacing it with the equivalent X - U*(U'*X) in two places. --- src/rulesets/LinearAlgebra/factorization.jl | 8 ++++---- src/rulesets/LinearAlgebra/utils.jl | 9 --------- test/rulesets/LinearAlgebra/factorization.jl | 1 - 3 files changed, 4 insertions(+), 14 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 910dd744b..60afe0a42 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -262,16 +262,16 @@ function svd_rev(USV::SVD, Ū, s̄, V̄) Ut = U' FUᵀŪ = _mulsubtrans!!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU) FVᵀV̄ = _mulsubtrans!!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV) - ImUUᵀ = _eyesubx!(U*Ut) # I - UUᵀ - ImVVᵀ = _eyesubx!(V*Vt) # I - VVᵀ S = Diagonal(s) S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄) # TODO: consider using MuladdMacro here - Ā = add!!(U * FUᵀŪ * S, ImUUᵀ * (Ū / S)) * Vt + Ūs = Ū / S + V̄ts = S \ V̄' + Ā = add!!(U * FUᵀŪ * S, Ūs - U * (Ut * Ūs)) * Vt Ā = add!!(Ā, U * S̄ * Vt) - Ā = add!!(Ā, U * add!!(S * FVᵀV̄ * Vt, (S \ V̄') * ImVVᵀ)) + Ā = add!!(Ā, U * add!!(S * FVᵀV̄ * Vt, V̄ts - (V̄ts * V) * Vt)) return Ā end diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 3d8ad923f..71a30ecc7 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -19,15 +19,6 @@ _mulsubtrans!!(X::AbstractZero, F::AbstractZero) = X _mulsubtrans!!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X _mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F -# I - X, overwrites X -function _eyesubx!(X::AbstractMatrix) - n, m = size(X) - @inbounds for j = 1:m, i = 1:n - X[i,j] = (i == j) - X[i,j] - end - return X -end - _extract_imag(x) = complex(0, imag(x)) """ diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 60e2e74be..52a238ff4 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -151,7 +151,6 @@ end X = randn(10, 10) Y = randn(10, 10) @test ChainRules._mulsubtrans!!(copy(X), Y) ≈ Y .* (X - X') - @test ChainRules._eyesubx!(copy(X)) ≈ I - X Z = randn(Float32, 10, 10) result = ChainRules._mulsubtrans!!(copy(Z), Y) From f288c7a5c300724198f1c4a82cefbf14cc310e25 Mon Sep 17 00:00:00 2001 From: Per Rutquist Date: Mon, 4 Dec 2023 15:44:25 +0100 Subject: [PATCH 2/5] Rewriting rev_svd to (hopefully) be faster This uses fewer matrix multiplications. The code no longer uses the helper function _mulsubtrans!! so it has been removed. --- src/rulesets/LinearAlgebra/factorization.jl | 31 ++++++++++---------- src/rulesets/LinearAlgebra/utils.jl | 18 ------------ test/rulesets/LinearAlgebra/factorization.jl | 11 ------- 3 files changed, 16 insertions(+), 44 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index 60afe0a42..e71049537 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -216,7 +216,7 @@ end ##### function _svd_pullback(Ȳ::Tangent, F) - ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt') + ∂X = svd_rev(F, Ȳ.U, Ȳ.S, Ȳ.Vt) return (NoTangent(), ∂X) end _svd_pullback(Ȳ::AbstractThunk, F) = _svd_pullback(unthunk(Ȳ), F) @@ -244,34 +244,35 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD end # When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix` -function svd_rev(USV::SVD, Ū, s̄, V̄) +function svd_rev(USV::SVD, Ū, s̄, V̄t) # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default U = USV.U s = USV.S - V = USV.V Vt = USV.Vt k = length(s) T = eltype(s) F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k] - # We do a lot of matrix operations here, so we'll try to be memory-friendly and do - # as many of the computations in-place as possible. Benchmarking shows that the in- - # place functions here are significantly faster than their out-of-place, naively - # implemented counterparts, and allocate no additional memory. - Ut = U' - FUᵀŪ = _mulsubtrans!!(Ut*Ū, F) # F .* (UᵀŪ - ŪᵀU) - FVᵀV̄ = _mulsubtrans!!(Vt*V̄, F) # F .* (VᵀV̄ - V̄ᵀV) + UtŪ = U'*Ū + V̄tV = V̄t*Vt' + FUᵀŪS = F .* (UtŪ .- UtŪ') .* s' + SFVᵀV̄ = F .* (V̄tV' .- V̄tV) .* s + S = Diagonal(s) S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄) + + Ā = U * (FUᵀŪS + S̄ + SFVᵀV̄) * Vt # TODO: consider using MuladdMacro here - Ūs = Ū / S - V̄ts = S \ V̄' - Ā = add!!(U * FUᵀŪ * S, Ūs - U * (Ut * Ūs)) * Vt - Ā = add!!(Ā, U * S̄ * Vt) - Ā = add!!(Ā, U * add!!(S * FVᵀV̄ * Vt, V̄ts - (V̄ts * V) * Vt)) + if size(U,1) > size(U,2) + Ā = add!!(Ā, ((Ū .- U * UtŪ) / S) * Vt) + end + + if size(Vt,2) > size(Vt,1) + Ā = add!!(Ā, U * (S \ (V̄t .- V̄tV * Vt))) + end return Ā end diff --git a/src/rulesets/LinearAlgebra/utils.jl b/src/rulesets/LinearAlgebra/utils.jl index 71a30ecc7..f13758f0e 100644 --- a/src/rulesets/LinearAlgebra/utils.jl +++ b/src/rulesets/LinearAlgebra/utils.jl @@ -1,24 +1,6 @@ # Some utility functions for optimizing linear algebra operations that aren't specific # to any particular rule definition -# F .* (X - X'), overwrites X if possible -function _mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractMatrix{<:Real}) - T = promote_type(eltype(X), eltype(F)) - Y = (T <: eltype(X)) ? X : similar(X, T) - k = size(X, 1) - @inbounds for j = 1:k, i = 1:j # Iterate the upper triangle - if i == j - Y[i,i] = zero(T) - else - Y[i,j], Y[j,i] = F[i,j] * (X[i,j] - X[j,i]), F[j,i] * (X[j,i] - X[i,j]) - end - end - return Y -end -_mulsubtrans!!(X::AbstractZero, F::AbstractZero) = X -_mulsubtrans!!(X::AbstractZero, F::AbstractMatrix{<:Real}) = X -_mulsubtrans!!(X::AbstractMatrix{<:Real}, F::AbstractZero) = F - _extract_imag(x) = complex(0, imag(x)) """ diff --git a/test/rulesets/LinearAlgebra/factorization.jl b/test/rulesets/LinearAlgebra/factorization.jl index 52a238ff4..3832f63d5 100644 --- a/test/rulesets/LinearAlgebra/factorization.jl +++ b/test/rulesets/LinearAlgebra/factorization.jl @@ -146,17 +146,6 @@ end @test dX_thunked == dX_unthunked end end - - @testset "Helper functions" begin - X = randn(10, 10) - Y = randn(10, 10) - @test ChainRules._mulsubtrans!!(copy(X), Y) ≈ Y .* (X - X') - - Z = randn(Float32, 10, 10) - result = ChainRules._mulsubtrans!!(copy(Z), Y) - @test result ≈ Y .* (Z - Z') - @test eltype(result) == Float64 - end end @testset "eigendecomposition" begin From ef235e15fd6d3c61248a7bd3581033c193c313a5 Mon Sep 17 00:00:00 2001 From: Per Rutquist Date: Mon, 4 Dec 2023 20:20:43 +0100 Subject: [PATCH 3/5] Re-write svd_rev to avoid a matrix multiplication --- src/rulesets/LinearAlgebra/factorization.jl | 15 ++++++--------- 1 file changed, 6 insertions(+), 9 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index e71049537..cd73654b5 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -263,15 +263,12 @@ function svd_rev(USV::SVD, Ū, s̄, V̄t) S = Diagonal(s) S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄) - Ā = U * (FUᵀŪS + S̄ + SFVᵀV̄) * Vt - - # TODO: consider using MuladdMacro here - if size(U,1) > size(U,2) - Ā = add!!(Ā, ((Ū .- U * UtŪ) / S) * Vt) - end - - if size(Vt,2) > size(Vt,1) - Ā = add!!(Ā, U * (S \ (V̄t .- V̄tV * Vt))) + if size(Vt,1) == size(Vt,2) + # V is square, VVᵀ = I and therefore V̄ᵀ - V̄ᵀVVᵀ = 0 + Ā = (U * (FUᵀŪS + S̄ + SFVᵀV̄) + ((Ū .- U * UtŪ) / S)) * Vt + else + # If V is not square then U is, so UUᵀ == I and Ū - UUᵀŪ = 0 + Ā = U * ((FUᵀŪS + S̄ + SFVᵀV̄) * Vt + (S \ (V̄t .- V̄tV * Vt))) end return Ā From 225dd72d43359cf4b6092975f3f7f2cdf2a8f788 Mon Sep 17 00:00:00 2001 From: Per Rutquist Date: Tue, 5 Dec 2023 22:44:50 +0100 Subject: [PATCH 4/5] Rewrite svd_rev to reduce allocations --- src/rulesets/LinearAlgebra/factorization.jl | 33 ++++++++++----------- 1 file changed, 16 insertions(+), 17 deletions(-) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index cd73654b5..a1bb400a0 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -243,7 +243,7 @@ function rrule(::typeof(getproperty), F::T, x::Symbol) where T <: SVD return getproperty(F, x), getproperty_svd_pullback end -# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄::AbstractMatrix` +# When not `ZeroTangent`s expect `Ū::AbstractMatrix, s̄::AbstractVector, V̄t::AbstractMatrix` function svd_rev(USV::SVD, Ū, s̄, V̄t) # Note: assuming a thin factorization, i.e. svd(A, full=false), which is the default U = USV.U @@ -252,25 +252,24 @@ function svd_rev(USV::SVD, Ū, s̄, V̄t) k = length(s) T = eltype(s) - F = T[i == j ? 1 : inv(@inbounds s[j]^2 - s[i]^2) for i = 1:k, j = 1:k] - - UtŪ = U'*Ū - V̄tV = V̄t*Vt' - - FUᵀŪS = F .* (UtŪ .- UtŪ') .* s' - SFVᵀV̄ = F .* (V̄tV' .- V̄tV) .* s - - S = Diagonal(s) - S̄ = s̄ isa AbstractZero ? s̄ : Diagonal(s̄) - - if size(Vt,1) == size(Vt,2) + UtŪ = U' * Ū + V̄tV = V̄t * Vt' + M = @inbounds T[ + if i == j + s̄[i] + else + (s[j] * (UtŪ[i, j] - UtŪ[j, i]) + s[i] * (V̄tV[j, i] - V̄tV[i, j])) / + (s[j]^2 - s[i]^2) + end for i in 1:k, j in 1:k + ] + + if size(Vt, 1) == size(Vt, 2) # V is square, VVᵀ = I and therefore V̄ᵀ - V̄ᵀVVᵀ = 0 - Ā = (U * (FUᵀŪS + S̄ + SFVᵀV̄) + ((Ū .- U * UtŪ) / S)) * Vt - else + Ā = (U * M .+ ((Ū .- U * UtŪ) ./ s')) * Vt + else # If V is not square then U is, so UUᵀ == I and Ū - UUᵀŪ = 0 - Ā = U * ((FUᵀŪS + S̄ + SFVᵀV̄) * Vt + (S \ (V̄t .- V̄tV * Vt))) + Ā = U * (M * Vt .+ ((V̄t .- V̄tV * Vt) ./ s)) end - return Ā end From 231ac1a0dda79325826e5018ae2dc0711d568b05 Mon Sep 17 00:00:00 2001 From: Per Rutquist Date: Wed, 15 May 2024 10:03:52 +0200 Subject: [PATCH 5/5] Adding a specialized method of svd_rev where only s-bar is nonzero. The general-case method also works in this case, but is slightly slower because it creates a dense matrix M even though only the diagonal entries are nonzero. --- src/rulesets/LinearAlgebra/factorization.jl | 5 +++++ 1 file changed, 5 insertions(+) diff --git a/src/rulesets/LinearAlgebra/factorization.jl b/src/rulesets/LinearAlgebra/factorization.jl index a1bb400a0..3a3a1f821 100644 --- a/src/rulesets/LinearAlgebra/factorization.jl +++ b/src/rulesets/LinearAlgebra/factorization.jl @@ -273,6 +273,11 @@ function svd_rev(USV::SVD, Ū, s̄, V̄t) return Ā end +function svd_rev(USV::SVD, ::AbstractZero, s̄::AbstractVector, ::AbstractZero) + Ā = USV.U * Diagonal(s̄) * USV.Vt + return Ā +end + ##### ##### `eigen` #####