Skip to content

Commit

Permalink
Make adjoint and solve work for most factorizations (#40899)
Browse files Browse the repository at this point in the history
* Add adjoint for Cholesky

* Implement adjoint for BunchKaufman

* Fix ldiv! for adjoints of Hessenbergs

* Add adjoint of LDLt

* Fix return for tall problems in fallback \ method for adjoint of
Factorizations to make \ work for adjoint LQ.

* Fix qr(A)'\b

* Define adjoint for SVD

* Improve promotion in fallback by defining general convert methods
for Factorizations

* Fix ldiv! for SVD

* Restrict the general \ definition that handles over- and underdetermined
systems to LAPACK factorizations

* Remove redundant \ definitions in diagonal.jl

* Add Factorization constructors for SVD

* Disambiguate between the specialized \ for real lhs-complex rhs and
then new \ for LAPACKFactorizations.

* Adjustments based on review

* Fixes for new pivoting syntax
  • Loading branch information
andreasnoack authored May 30, 2021
1 parent 311ff56 commit acdffeb
Show file tree
Hide file tree
Showing 17 changed files with 272 additions and 75 deletions.
57 changes: 57 additions & 0 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -397,6 +397,63 @@ const ⋅ = dot
const × = cross
export , ×

## convenience methods
## return only the solution of a least squares problem while avoiding promoting
## vectors to matrices.
_cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x
_cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X

## append right hand side with zeros if necessary
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))

# General fallback definition for handling under- and overdetermined system as well as square problems
# While this definition is pretty general, it does e.g. promote to common element type of lhs and rhs
# which is required by LAPACK but not SuiteSpase which allows real-complex solves in some cases. Hence,
# we restrict this method to only the LAPACK factorizations in LinearAlgebra.
# The definition is put here since it explicitly references all the Factorizion structs so it has
# to be located after all the files that define the structs.
const LAPACKFactorizations{T,S} = Union{
BunchKaufman{T,S},
Cholesky{T,S},
LQ{T,S},
LU{T,S},
QR{T,S},
QRCompactWY{T,S},
QRPivoted{T,S},
SVD{T,<:Real,S}}
function (\)(F::Union{<:LAPACKFactorizations,Adjoint{<:Any,<:LAPACKFactorizations}}, B::AbstractVecOrMat)
require_one_based_indexing(B)
m, n = size(F)
if m != size(B, 1)
throw(DimensionMismatch("arguments must have the same number of rows"))
end

TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
FF = Factorization{TFB}(F)

# For wide problem we (often) compute a minimum norm solution. The solution
# is larger than the right hand side so we use size(F, 2).
BB = _zeros(TFB, B, n)

if n > size(B, 1)
# Underdetermined
copyto!(view(BB, 1:m, :), B)
else
copyto!(BB, B)
end

ldiv!(FF, BB)

# For tall problems, we compute a least squares solution so only part
# of the rhs should be returned from \ while ldiv! uses (and returns)
# the complete rhs
return _cut_B(BB, 1:n)
end
# disambiguate
(\)(F::LAPACKFactorizations{T}, B::VecOrMat{Complex{T}}) where {T<:BlasReal} =
invoke(\, Tuple{Factorization{T}, VecOrMat{Complex{T}}}, F, B)

"""
LinearAlgebra.peakflops(n::Integer=2000; parallel::Bool=false)
Expand Down
16 changes: 11 additions & 5 deletions stdlib/LinearAlgebra/src/bunchkaufman.jl
Original file line number Diff line number Diff line change
Expand Up @@ -196,16 +196,14 @@ julia> S.L*S.D*S.L' - A[S.p, S.p]
bunchkaufman(A::AbstractMatrix{T}, rook::Bool=false; check::Bool = true) where {T} =
bunchkaufman!(copy_oftype(A, typeof(sqrt(oneunit(T)))), rook; check = check)

convert(::Type{BunchKaufman{T}}, B::BunchKaufman{T}) where {T} = B
convert(::Type{BunchKaufman{T}}, B::BunchKaufman) where {T} =
BunchKaufman{T}(B::BunchKaufman) where {T} =
BunchKaufman(convert(Matrix{T}, B.LD), B.ipiv, B.uplo, B.symmetric, B.rook, B.info)
convert(::Type{Factorization{T}}, B::BunchKaufman{T}) where {T} = B
convert(::Type{Factorization{T}}, B::BunchKaufman) where {T} = convert(BunchKaufman{T}, B)
Factorization{T}(B::BunchKaufman) where {T} = BunchKaufman{T}(B)

size(B::BunchKaufman) = size(getfield(B, :LD))
size(B::BunchKaufman, d::Integer) = size(getfield(B, :LD), d)
issymmetric(B::BunchKaufman) = B.symmetric
ishermitian(B::BunchKaufman) = !B.symmetric
ishermitian(B::BunchKaufman{T}) where T = T<:Real || !B.symmetric

function _ipiv2perm_bk(v::AbstractVector{T}, maxi::Integer, uplo::AbstractChar, rook::Bool) where T
require_one_based_indexing(v)
Expand Down Expand Up @@ -279,6 +277,14 @@ Base.propertynames(B::BunchKaufman, private::Bool=false) =

issuccess(B::BunchKaufman) = B.info == 0

function adjoint(B::BunchKaufman)
if ishermitian(B)
return B
else
throw(ArgumentError("adjoint not implemented for complex symmetric matrices"))
end
end

function Base.show(io::IO, mime::MIME{Symbol("text/plain")}, B::BunchKaufman)
if issuccess(B)
summary(io, B); println(io)
Expand Down
2 changes: 2 additions & 0 deletions stdlib/LinearAlgebra/src/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -529,6 +529,8 @@ Base.propertynames(F::CholeskyPivoted, private::Bool=false) =

issuccess(C::Union{Cholesky,CholeskyPivoted}) = C.info == 0

adjoint(C::Union{Cholesky,CholeskyPivoted}) = C

function show(io::IO, mime::MIME{Symbol("text/plain")}, C::Cholesky{<:Any,<:AbstractMatrix})
if issuccess(C)
summary(io, C); println(io)
Expand Down
8 changes: 0 additions & 8 deletions stdlib/LinearAlgebra/src/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -488,14 +488,6 @@ rdiv!(A::AbstractMatrix{T}, transD::Transpose{<:Any,<:Diagonal{T}}) where {T} =
(/)(A::Union{StridedMatrix, AbstractTriangular}, D::Diagonal) =
rdiv!((typeof(oneunit(eltype(D))/oneunit(eltype(A)))).(A), D)

(\)(F::Factorization, D::Diagonal) =
ldiv!(F, Matrix{typeof(oneunit(eltype(D))/oneunit(eltype(F)))}(D))
\(adjF::Adjoint{<:Any,<:Factorization}, D::Diagonal) =
(F = adjF.parent; ldiv!(adjoint(F), Matrix{typeof(oneunit(eltype(D))/oneunit(eltype(F)))}(D)))
(\)(A::Union{QR,QRCompactWY,QRPivoted}, B::Diagonal) =
invoke(\, Tuple{Union{QR,QRCompactWY,QRPivoted}, AbstractVecOrMat}, A, B)


@inline function kron!(C::AbstractMatrix, A::Diagonal, B::Diagonal)
valA = A.diag; nA = length(valA)
valB = B.diag; nB = length(valB)
Expand Down
24 changes: 6 additions & 18 deletions stdlib/LinearAlgebra/src/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,9 @@ convert(::Type{T}, f::Factorization) where {T<:AbstractArray} = T(f)

### General promotion rules
Factorization{T}(F::Factorization{T}) where {T} = F
# This is a bit odd since the return is not a Factorization but it works well in generic code
Factorization{T}(A::Adjoint{<:Any,<:Factorization}) where {T} =
adjoint(Factorization{T}(parent(A)))
inv(F::Factorization{T}) where {T} = (n = size(F, 1); ldiv!(F, Matrix{T}(I, n, n)))

Base.hash(F::Factorization, h::UInt) = mapreduce(f -> hash(getfield(F, f)), hash, 1:nfields(F); init=h)
Expand Down Expand Up @@ -96,40 +99,25 @@ function (/)(B::VecOrMat{Complex{T}}, F::Factorization{T}) where T<:BlasReal
return copy(reinterpret(Complex{T}, x))
end

function \(F::Factorization, B::AbstractVecOrMat)
function \(F::Union{Factorization, Adjoint{<:Any,<:Factorization}}, B::AbstractVecOrMat)
require_one_based_indexing(B)
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copyto!(BB, B)
ldiv!(F, BB)
end
function \(adjF::Adjoint{<:Any,<:Factorization}, B::AbstractVecOrMat)
require_one_based_indexing(B)
F = adjF.parent
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copyto!(BB, B)
ldiv!(adjoint(F), BB)
end

function /(B::AbstractMatrix, F::Factorization)
function /(B::AbstractMatrix, F::Union{Factorization, Adjoint{<:Any,<:Factorization}})
require_one_based_indexing(B)
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copyto!(BB, B)
rdiv!(BB, F)
end
function /(B::AbstractMatrix, adjF::Adjoint{<:Any,<:Factorization})
require_one_based_indexing(B)
F = adjF.parent
TFB = typeof(oneunit(eltype(B)) / oneunit(eltype(F)))
BB = similar(B, TFB, size(B))
copyto!(BB, B)
rdiv!(BB, adjoint(F))
end
/(adjB::AdjointAbsVec, adjF::Adjoint{<:Any,<:Factorization}) = adjoint(adjF.parent \ adjB.parent)
/(B::TransposeAbsVec, adjF::Adjoint{<:Any,<:Factorization}) = adjoint(adjF.parent \ adjoint(B))


# support the same 3-arg idiom as in our other in-place A_*_B functions:
function ldiv!(Y::AbstractVecOrMat, A::Factorization, B::AbstractVecOrMat)
require_one_based_indexing(Y, B)
Expand Down
14 changes: 8 additions & 6 deletions stdlib/LinearAlgebra/src/hessenberg.jl
Original file line number Diff line number Diff line change
Expand Up @@ -564,28 +564,30 @@ function AbstractMatrix(F::Hessenberg)
end
end

# adjoint(Q::HessenbergQ{<:Real})

lmul!(Q::BlasHessenbergQ{T,false}, X::StridedVecOrMat{T}) where {T<:BlasFloat} =
LAPACK.ormhr!('L', 'N', 1, size(Q.factors, 1), Q.factors, Q.τ, X)
rmul!(X::StridedMatrix{T}, Q::BlasHessenbergQ{T,false}) where {T<:BlasFloat} =
rmul!(X::StridedVecOrMat{T}, Q::BlasHessenbergQ{T,false}) where {T<:BlasFloat} =
LAPACK.ormhr!('R', 'N', 1, size(Q.factors, 1), Q.factors, Q.τ, X)
lmul!(adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,false}}, X::StridedVecOrMat{T}) where {T<:BlasFloat} =
(Q = adjQ.parent; LAPACK.ormhr!('L', ifelse(T<:Real, 'T', 'C'), 1, size(Q.factors, 1), Q.factors, Q.τ, X))
rmul!(X::StridedMatrix{T}, adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,false}}) where {T<:BlasFloat} =
rmul!(X::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,false}}) where {T<:BlasFloat} =
(Q = adjQ.parent; LAPACK.ormhr!('R', ifelse(T<:Real, 'T', 'C'), 1, size(Q.factors, 1), Q.factors, Q.τ, X))

lmul!(Q::BlasHessenbergQ{T,true}, X::StridedVecOrMat{T}) where {T<:BlasFloat} =
LAPACK.ormtr!('L', Q.uplo, 'N', Q.factors, Q.τ, X)
rmul!(X::StridedMatrix{T}, Q::BlasHessenbergQ{T,true}) where {T<:BlasFloat} =
rmul!(X::StridedVecOrMat{T}, Q::BlasHessenbergQ{T,true}) where {T<:BlasFloat} =
LAPACK.ormtr!('R', Q.uplo, 'N', Q.factors, Q.τ, X)
lmul!(adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,true}}, X::StridedVecOrMat{T}) where {T<:BlasFloat} =
(Q = adjQ.parent; LAPACK.ormtr!('L', Q.uplo, ifelse(T<:Real, 'T', 'C'), Q.factors, Q.τ, X))
rmul!(X::StridedMatrix{T}, adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,true}}) where {T<:BlasFloat} =
rmul!(X::StridedVecOrMat{T}, adjQ::Adjoint{<:Any,<:BlasHessenbergQ{T,true}}) where {T<:BlasFloat} =
(Q = adjQ.parent; LAPACK.ormtr!('R', Q.uplo, ifelse(T<:Real, 'T', 'C'), Q.factors, Q.τ, X))

lmul!(Q::HessenbergQ{T}, X::Adjoint{T,<:StridedVecOrMat{T}}) where {T} = rmul!(X', Q')'
rmul!(X::Adjoint{T,<:StridedMatrix{T}}, Q::HessenbergQ{T}) where {T} = lmul!(Q', X')'
rmul!(X::Adjoint{T,<:StridedVecOrMat{T}}, Q::HessenbergQ{T}) where {T} = lmul!(Q', X')'
lmul!(adjQ::Adjoint{<:Any,<:HessenbergQ{T}}, X::Adjoint{T,<:StridedVecOrMat{T}}) where {T} = rmul!(X', adjQ')'
rmul!(X::Adjoint{T,<:StridedMatrix{T}}, adjQ::Adjoint{<:Any,<:HessenbergQ{T}}) where {T} = lmul!(adjQ', X')'
rmul!(X::Adjoint{T,<:StridedVecOrMat{T}}, adjQ::Adjoint{<:Any,<:HessenbergQ{T}}) where {T} = lmul!(adjQ', X')'

# multiply x by the entries of M in the upper-k triangle, which contains
# the entries of the upper-Hessenberg matrix H for k=-1
Expand Down
3 changes: 3 additions & 0 deletions stdlib/LinearAlgebra/src/ldlt.jl
Original file line number Diff line number Diff line change
Expand Up @@ -77,6 +77,9 @@ function getproperty(F::LDLt, d::Symbol)
end
end

adjoint(F::LDLt{<:Real,<:SymTridiagonal}) = F
adjoint(F::LDLt) = LDLt(copy(adjoint(F.data)))

function show(io::IO, mime::MIME{Symbol("text/plain")}, F::LDLt)
summary(io, F); println(io)
println(io, "L factor:")
Expand Down
30 changes: 16 additions & 14 deletions stdlib/LinearAlgebra/src/lq.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,8 +125,8 @@ lq_eltype(::Type{T}) where {T} = typeof(zero(T) / sqrt(abs2(one(T))))
copy(A::LQ) = LQ(copy(A.factors), copy(A.τ))

LQ{T}(A::LQ) where {T} = LQ(convert(AbstractMatrix{T}, A.factors), convert(Vector{T}, A.τ))
Factorization{T}(A::LQ{T}) where {T} = A
Factorization{T}(A::LQ) where {T} = LQ{T}(A)

AbstractMatrix(A::LQ) = A.L*A.Q
AbstractArray(A::LQ) = AbstractMatrix(A)
Matrix(A::LQ) = Array(AbstractArray(A))
Expand Down Expand Up @@ -194,7 +194,7 @@ function lmul!(A::LQ, B::StridedVecOrMat)
end
function *(A::LQ{TA}, B::StridedVecOrMat{TB}) where {TA,TB}
TAB = promote_type(TA, TB)
_cut_B(lmul!(Factorization{TAB}(A), copy_oftype(B, TAB)), 1:size(A,1))
_cut_B(lmul!(convert(Factorization{TAB}, A), copy_oftype(B, TAB)), 1:size(A,1))
end

## Multiplication by Q
Expand Down Expand Up @@ -318,17 +318,6 @@ _rightappdimmismatch(rowsorcols) =
"or (2) the number of rows of that (LQPackedQ) matrix's internal representation ",
"(the factorization's originating matrix's number of rows)")))


function (\)(A::LQ{TA},B::StridedVecOrMat{TB}) where {TA,TB}
S = promote_type(TA,TB)
m, n = size(A)
m n || throw(DimensionMismatch("LQ solver does not support overdetermined systems (more rows than columns)"))
m == size(B,1) || throw(DimensionMismatch("Both inputs should have the same number of rows"))
AA = Factorization{S}(A)
X = _zeros(S, B, n)
X[1:size(B, 1), :] = B
return ldiv!(AA, X)
end
# With a real lhs and complex rhs with the same precision, we can reinterpret
# the complex rhs as a real rhs with twice the number of columns
function (\)(F::LQ{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
Expand All @@ -342,12 +331,25 @@ function (\)(F::LQ{T}, B::VecOrMat{Complex{T}}) where T<:BlasReal
end


function ldiv!(A::LQ{T}, B::StridedVecOrMat{T}) where T
function ldiv!(A::LQ, B::StridedVecOrMat)
require_one_based_indexing(B)
m, n = size(A)
m n || throw(DimensionMismatch("LQ solver does not support overdetermined systems (more rows than columns)"))

ldiv!(LowerTriangular(A.L), view(B, 1:size(A,1), axes(B,2)))
return lmul!(adjoint(A.Q), B)
end

function ldiv!(Fadj::Adjoint{<:Any,<:LQ}, B::StridedVecOrMat)
require_one_based_indexing(B)
m, n = size(Fadj)
m >= n || throw(DimensionMismatch("solver does not support underdetermined systems (more columns than rows)"))

F = parent(Fadj)
lmul!(F.Q, B)
ldiv!(UpperTriangular(adjoint(F.L)), view(B, 1:size(F,1), axes(B,2)))
return B
end

# In LQ factorization, `Q` is expressed as the product of the adjoint of the
# reflectors. Thus, `det` has to be conjugated.
Expand Down
45 changes: 27 additions & 18 deletions stdlib/LinearAlgebra/src/qr.jl
Original file line number Diff line number Diff line change
Expand Up @@ -472,6 +472,8 @@ end
Base.propertynames(F::QRPivoted, private::Bool=false) =
(:R, :Q, :p, :P, (private ? fieldnames(typeof(F)) : ())...)

adjoint(F::Union{QR,QRPivoted,QRCompactWY}) = Adjoint(F)

abstract type AbstractQ{T} <: AbstractMatrix{T} end

inv(Q::AbstractQ) = Q'
Expand Down Expand Up @@ -939,28 +941,35 @@ function ldiv!(A::QRPivoted, B::StridedMatrix)
B
end

# convenience methods
## return only the solution of a least squares problem while avoiding promoting
## vectors to matrices.
_cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x
_cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X

## append right hand side with zeros if necessary
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))
function _apply_permutation!(F::QRPivoted, B::AbstractVecOrMat)
# Apply permutation but only to the top part of the solution vector since
# it's padded with zeros for underdetermined problems
B[1:length(F.p), :] = B[F.p, :]
return B
end
_apply_permutation!(F::Factorization, B::AbstractVecOrMat) = B

function (\)(A::Union{QR{TA},QRCompactWY{TA},QRPivoted{TA}}, B::AbstractVecOrMat{TB}) where {TA,TB}
function ldiv!(Fadj::Adjoint{<:Any,<:Union{QR,QRCompactWY,QRPivoted}}, B::AbstractVecOrMat)
require_one_based_indexing(B)
S = promote_type(TA,TB)
m, n = size(A)
m == size(B,1) || throw(DimensionMismatch("Both inputs should have the same number of rows"))
m, n = size(Fadj)

AA = Factorization{S}(A)
# We don't allow solutions overdetermined systems
if m > n
throw(DimensionMismatch("overdetermined systems are not supported"))
end
if n != size(B, 1)
throw(DimensionMismatch("inputs should have the same number of rows"))
end
F = parent(Fadj)

X = _zeros(S, B, n)
X[1:size(B, 1), :] = B
ldiv!(AA, X)
return _cut_B(X, 1:n)
B = _apply_permutation!(F, B)

# For underdetermined system, the triangular solve should only be applied to the top
# part of B that contains the rhs. For square problems, the view corresponds to B itself
ldiv!(LowerTriangular(adjoint(F.R)), view(B, 1:size(F.R, 2), :))
lmul!(F.Q, B)

return B
end

# With a real lhs and complex rhs with the same precision, we can reinterpret the complex
Expand Down
14 changes: 12 additions & 2 deletions stdlib/LinearAlgebra/src/svd.jl
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,11 @@ function SVD{T}(U::AbstractArray, S::AbstractVector{Tr}, Vt::AbstractArray) wher
convert(AbstractArray{T}, Vt))
end

SVD{T}(F::SVD) where {T} = SVD(
convert(AbstractMatrix{T}, F.U),
convert(AbstractVector{real(T)}, F.S),
convert(AbstractMatrix{T}, F.Vt))
Factorization{T}(F::SVD) where {T} = SVD{T}(F)

# iteration for destructuring into components
Base.iterate(S::SVD) = (S.U, Val(:S))
Expand Down Expand Up @@ -235,10 +240,11 @@ svdvals(A::AbstractVector{<:BlasFloat}) = [norm(A)]
svdvals(x::Number) = abs(x)
svdvals(S::SVD{<:Any,T}) where {T} = (S.S)::Vector{T}

# SVD least squares
### SVD least squares ###
function ldiv!(A::SVD{T}, B::StridedVecOrMat) where T
m, n = size(A)
k = searchsortedlast(A.S, eps(real(T))*A.S[1], rev=true)
view(A.Vt,1:k,:)' * (view(A.S,1:k) .\ (view(A.U,:,1:k)' * B))
return mul!(view(B, 1:n, :), view(A.Vt, 1:k, :)', view(A.S, 1:k) .\ (view(A.U, :, 1:k)' * _cut_B(B, 1:m)))
end

function inv(F::SVD{T}) where T
Expand All @@ -252,6 +258,10 @@ end
size(A::SVD, dim::Integer) = dim == 1 ? size(A.U, dim) : size(A.Vt, dim)
size(A::SVD) = (size(A, 1), size(A, 2))

function adjoint(F::SVD)
return SVD(F.Vt', F.S, F.U')
end

function show(io::IO, mime::MIME{Symbol("text/plain")}, F::SVD{<:Any,<:Any,<:AbstractArray})
summary(io, F); println(io)
println(io, "U factor:")
Expand Down
Loading

0 comments on commit acdffeb

Please sign in to comment.