From a66ad918ef3cd325393fe8bbb3e846e921da267a Mon Sep 17 00:00:00 2001 From: Seth Axen Date: Fri, 6 Jan 2023 00:16:39 +0100 Subject: [PATCH] Add missing methods for Woodbury matrices (#117) * Add additional tests * Update methods * Increment patch number * Remove scalars --- Project.toml | 2 +- src/woodbury.jl | 25 +++++++++++++++++-------- test/woodbury.jl | 31 ++++++++++++++++++++++++++++++- 3 files changed, 48 insertions(+), 10 deletions(-) diff --git a/Project.toml b/Project.toml index 3a0b6955..2ce5f9c8 100644 --- a/Project.toml +++ b/Project.toml @@ -1,7 +1,7 @@ name = "Pathfinder" uuid = "b1d3bc72-d0e7-4279-b92f-7fa5d6d2d454" authors = ["Seth Axen and contributors"] -version = "0.6.0" +version = "0.6.1" [deps] AbstractDifferentiation = "c29ec348-61ec-40c8-8164-b8c60e9d9f3d" diff --git a/src/woodbury.jl b/src/woodbury.jl index fcafcc31..4937afd3 100644 --- a/src/woodbury.jl +++ b/src/woodbury.jl @@ -117,13 +117,13 @@ function LinearAlgebra.mul!( r::StridedVecOrMat{T}, R::WoodburyPDRightFactor{T}, x::StridedVecOrMat{T} ) where {T} copyto!(r, x) - return lmul!(R, copyto!(r, x)) + return lmul!(R, r) end - -function Base.:*(R::WoodburyPDRightFactor, x::StridedVecOrMat) - T = Base.promote_eltype(R, x) - y = copyto!(similar(x, T), x) - return lmul!(R, y) +function LinearAlgebra.mul!( + r::StridedVecOrMat{T}, L::WoodburyPDLeftFactor{T}, x::StridedVecOrMat{T} +) where {T} + copyto!(r, x) + return lmul!(L, r) end function LinearAlgebra.lmul!(R::WoodburyPDRightFactor, x::StridedVecOrMat) @@ -329,8 +329,17 @@ function LinearAlgebra.lmul!(W::WoodburyPDMat, x::AbstractVecOrMat) return lmul!(factorize(W), x) end -function LinearAlgebra.mul!(y::AbstractVector, W::WoodburyPDMat, x::AbstractVecOrMat) - return lmul!(W, copyto!(y, x)) +LinearAlgebra.ldiv!(W::WoodburyPDMat, x::AbstractVecOrMat) = ldiv!(factorize(W), x) + +function LinearAlgebra.mul!(y::AbstractVecOrMat, W::WoodburyPDMat, x::AbstractVecOrMat) + copyto!(y, x) + return lmul!(W, y) +end + +function Base.:\(W::WoodburyPDMat, x::AbstractVecOrMat) + y = similar(x, Base.promote_eltype(W, x)) + copyto!(y, x) + return ldiv!(W, y) end function Base.:*(W::WoodburyPDMat, c::Real) diff --git a/test/woodbury.jl b/test/woodbury.jl index a59c848e..56d9d9df 100644 --- a/test/woodbury.jl +++ b/test/woodbury.jl @@ -56,6 +56,7 @@ test_factorization(W::WoodburyPDMat) = test_factorization(W.A, W.B, W.D, W.F) @test Z \ x ≈ Zmat \ x @test Z' \ x ≈ Zmat' \ x @test mul!(similar(x), Z, x) ≈ Zmat * x + @test mul!(similar(x), Z', x) ≈ Zmat' * x @test lmul!(Z, copy(x)) ≈ Zmat * x @test ldiv!(Z, copy(x)) ≈ Zmat \ x end @@ -230,6 +231,18 @@ end @test X ≈ Y end + @testset "ldiv!" begin + x = randn(T, n) + y = Wmat \ x + @test ldiv!(W, x) === x + @test x ≈ y + + X = randn(T, n, 5) + Y = Wmat \ X + @test ldiv!(W, X) === X + @test X ≈ Y + end + @testset "mul!" begin x = randn(T, n) y = similar(x) @@ -254,10 +267,26 @@ end x = randn(T, n) @test W * x ≈ Wmat * x - X = randn(T, n) + X = randn(T, n, 2) @test W * X ≈ Wmat * X end + @testset "\\" begin + x = randn(T, n) + @test W \ x ≈ Wmat \ x + + X = randn(T, n, 2) + @test W \ X ≈ Wmat \ X + end + + @testset "/" begin + x = randn(T, n) + @test x' / W ≈ x' / Wmat + + X = randn(T, 2, n) + @test X / W ≈ X / Wmat + end + @testset "PDMats.dim" begin @test PDMats.dim(W) == n end