Skip to content

Commit

Permalink
Cleanup copy_oftype-like functions in factorizations (#43700)
Browse files Browse the repository at this point in the history
Co-authored-by: Daan Huybrechs <[email protected]>
  • Loading branch information
dkarrasch and daanhb authored Jan 14, 2022
1 parent b57d2e1 commit f15b4c3
Show file tree
Hide file tree
Showing 6 changed files with 46 additions and 59 deletions.
21 changes: 4 additions & 17 deletions stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -366,11 +366,11 @@ In general, the type of the output corresponds to that of `similar(A, T)`.
There are three often used methods in LinearAlgebra to create a mutable copy
of an array with a given eltype. These copies can be passed to in-place
algorithms (such as ldiv!, rdiv!, lu! and so on). Which one to use in practice
algorithms (such as `ldiv!`, `rdiv!`, `lu!` and so on). Which one to use in practice
depends on what is known (or assumed) about the structure of the array in that
algorithm.
See also: `copy_similar`, `copy_to_array`.
See also: `copy_similar`.
"""
copy_oftype(A::AbstractArray, ::Type{T}) where {T} = copyto!(similar(A, T), A)

Expand All @@ -380,25 +380,12 @@ copy_oftype(A::AbstractArray, ::Type{T}) where {T} = copyto!(similar(A, T), A)
Copy `A` to a mutable array with eltype `T` based on `similar(A, T, size(A))`.
Compared to `copy_oftype`, the result can be more flexible. In general, the type
of the output corresponds to that of the three-argument method `similar(A, T, size(s))`.
of the output corresponds to that of the three-argument method `similar(A, T, size(A))`.
See also: `copy_oftype`, `copy_to_array`.
See also: `copy_oftype`.
"""
copy_similar(A::AbstractArray, ::Type{T}) where {T} = copyto!(similar(A, T, size(A)), A)

"""
copy_to_array(A, T)
Copy `A` to a regular dense `Array` with element type `T`.
The resulting array is mutable. It can be used, for example, to pass the data of
`A` to an efficient in-place method for a matrix factorization such as `lu!`, in
cases where a more specific implementation of `lu!` (or `lu`) is not available.
See also: `copy_oftype`, `copy_similar`.
"""
copy_to_array(A::AbstractArray, ::Type{T}) where {T} = copyto!(Array{T}(undef, size(A)...), A)

# The three copy functions above return mutable arrays with eltype T.
# To only ensure a certain eltype, and if a mutable copy is not needed, it is
# more efficient to use:
Expand Down
2 changes: 1 addition & 1 deletion stdlib/LinearAlgebra/src/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,7 @@ end

#To allow Bidiagonal's where the "dv" is Vector{T} and "ev" Vector{S},
#where T and S can be promoted
function LinearAlgebra.Bidiagonal(dv::Vector{T}, ev::Vector{S}, uplo::Symbol) where {T,S}
function Bidiagonal(dv::Vector{T}, ev::Vector{S}, uplo::Symbol) where {T,S}
TS = promote_type(T,S)
return Bidiagonal{TS,Vector{TS}}(dv, ev, uplo)
end
Expand Down
15 changes: 4 additions & 11 deletions stdlib/LinearAlgebra/src/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -76,9 +76,6 @@ adjoint(F::LU) = Adjoint(F)
transpose(F::LU) = Transpose(F)

# StridedMatrix
lu(A::StridedMatrix, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
lu!(copy_oftype(A, lutype(eltype(A))), pivot; check=check)

lu!(A::StridedMatrix{<:BlasFloat}; check::Bool = true) = lu!(A, RowMaximum(); check=check)
function lu!(A::StridedMatrix{T}, ::RowMaximum; check::Bool = true) where {T<:BlasFloat}
lpt = LAPACK.getrf!(A)
Expand All @@ -89,9 +86,6 @@ function lu!(A::StridedMatrix{<:BlasFloat}, pivot::NoPivot; check::Bool = true)
return generic_lufact!(A, pivot; check = check)
end

lu(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) =
lu!(copy_oftype(A, lutype(eltype(A))), pivot; check=check)

function lu!(A::HermOrSym, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true)
copytri!(A.data, A.uplo, isa(A, Hermitian))
lu!(A.data, pivot; check = check)
Expand Down Expand Up @@ -282,13 +276,15 @@ true
```
"""
function lu(A::AbstractMatrix{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T}
S = lutype(T)
lu!(copy_to_array(A, S), pivot; check = check)
lu!(_lucopy(A, lutype(T)), pivot; check = check)
end
# TODO: remove for Julia v2.0
@deprecate lu(A::AbstractMatrix, ::Val{true}; check::Bool = true) lu(A, RowMaximum(); check=check)
@deprecate lu(A::AbstractMatrix, ::Val{false}; check::Bool = true) lu(A, NoPivot(); check=check)

_lucopy(A::AbstractMatrix, T) = copy_similar(A, T)
_lucopy(A::HermOrSym, T) = copy_oftype(A, T)
_lucopy(A::Tridiagonal, T) = copy_oftype(A, T)

lu(S::LU) = S
function lu(x::Number; check::Bool=true)
Expand Down Expand Up @@ -497,9 +493,6 @@ inv(A::LU{<:BlasFloat,<:StridedMatrix}) = inv!(copy(A))

# Tridiagonal

lu(A::Tridiagonal{T}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where T =
lu!(copy_oftype(A, lutype(T)), pivot; check = check)

# See dgttrf.f
function lu!(A::Tridiagonal{T,V}, pivot::Union{RowMaximum,NoPivot} = RowMaximum(); check::Bool = true) where {T,V}
# Extract values
Expand Down
22 changes: 6 additions & 16 deletions stdlib/LinearAlgebra/src/schur.jl
Original file line number Diff line number Diff line change
Expand Up @@ -97,7 +97,7 @@ julia> A
schur!(A::StridedMatrix{<:BlasFloat}) = Schur(LinearAlgebra.LAPACK.gees!('V', A)...)

"""
schur(A::StridedMatrix) -> F::Schur
schur(A) -> F::Schur
Computes the Schur factorization of the matrix `A`. The (quasi) triangular Schur factor can
be obtained from the `Schur` object `F` with either `F.Schur` or `F.T` and the
Expand Down Expand Up @@ -146,25 +146,20 @@ julia> t == F.T && z == F.Z && vals == F.values
true
```
"""
schur(A::StridedMatrix{<:BlasFloat}) = schur!(copy(A))
schur(A::StridedMatrix{T}) where T = schur!(copy_oftype(A, eigtype(T)))

schur(A::AbstractMatrix{T}) where {T} = schur!(copy_to_array(A, eigtype(T)))
schur(A::AbstractMatrix{T}) where {T} = schur!(copy_similar(A, eigtype(T)))
function schur(A::RealHermSymComplexHerm)
F = eigen(A; sortby=nothing)
return Schur(typeof(F.vectors)(Diagonal(F.values)), F.vectors, F.values)
end
function schur(A::Union{UnitUpperTriangular{T},UpperTriangular{T}}) where {T}
t = eigtype(T)
Z = Matrix{t}(undef, size(A)...)
copyto!(Z, A)
Z = copy_similar(A, t)
return Schur(Z, Matrix{t}(I, size(A)), convert(Vector{t}, diag(A)))
end
function schur(A::Union{UnitLowerTriangular{T},LowerTriangular{T}}) where {T}
t = eigtype(T)
# double flip the matrix A
Z = Matrix{t}(undef, size(A)...)
copyto!(Z, A)
Z = copy_similar(A, t)
reverse!(reshape(Z, :))
# construct "reverse" identity
n = size(A, 1)
Expand Down Expand Up @@ -338,7 +333,7 @@ schur!(A::StridedMatrix{T}, B::StridedMatrix{T}) where {T<:BlasFloat} =
GeneralizedSchur(LinearAlgebra.LAPACK.gges!('V', 'V', A, B)...)

"""
schur(A::StridedMatrix, B::StridedMatrix) -> F::GeneralizedSchur
schur(A, B) -> F::GeneralizedSchur
Computes the Generalized Schur (or QZ) factorization of the matrices `A` and `B`. The
(quasi) triangular Schur factors can be obtained from the `Schur` object `F` with `F.S`
Expand All @@ -350,14 +345,9 @@ generalized eigenvalues of `A` and `B` can be obtained with `F.α./F.β`.
Iterating the decomposition produces the components `F.S`, `F.T`, `F.Q`, `F.Z`,
`F.α`, and `F.β`.
"""
schur(A::StridedMatrix{T},B::StridedMatrix{T}) where {T<:BlasFloat} = schur!(copy(A),copy(B))
function schur(A::StridedMatrix{TA}, B::StridedMatrix{TB}) where {TA,TB}
S = promote_type(eigtype(TA), TB)
return schur!(copy_oftype(A, S), copy_oftype(B, S))
end
function schur(A::AbstractMatrix{TA}, B::AbstractMatrix{TB}) where {TA,TB}
S = promote_type(eigtype(TA), TB)
return schur!(copy_oftype(A, S), copy_oftype(B, S))
return schur!(copy_similar(A, S), copy_similar(B, S))
end

"""
Expand Down
9 changes: 9 additions & 0 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -50,6 +50,15 @@ Bidiagonal(A::AbstractTriangular) =
isbanded(A, -1, 0) ? Bidiagonal(diag(A, 0), diag(A, -1), :L) : # is lower bidiagonal
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))

_lucopy(A::Bidiagonal, T) = copy_oftype(Tridiagonal(A), T)
_lucopy(A::Diagonal, T) = copy_oftype(Tridiagonal(A), T)
function _lucopy(A::SymTridiagonal, T)
du = copy_similar(_evview(A), T)
dl = copy.(transpose.(du))
d = copy_similar(A.dv, T)
return Tridiagonal(dl, d, du)
end

const ConvertibleSpecialMatrix = Union{Diagonal,Bidiagonal,SymTridiagonal,Tridiagonal,AbstractTriangular}
const PossibleTriangularMatrix = Union{Diagonal, Bidiagonal, AbstractTriangular}

Expand Down
36 changes: 22 additions & 14 deletions stdlib/LinearAlgebra/test/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -402,20 +402,28 @@ end
end
end

@testset "lu(A) has a fallback for abstract matrices (#40831)" begin
# check that lu works for some structured arrays
A0 = rand(5, 5)
@test lu(Diagonal(A0)) isa LU
@test Matrix(lu(Diagonal(A0))) Diagonal(A0)
@test lu(Bidiagonal(A0, :U)) isa LU
@test Matrix(lu(Bidiagonal(A0, :U))) Bidiagonal(A0, :U)

# lu(A) copies A and then invokes lu!, make sure that the most efficient
# implementation of lu! continues to be used
A1 = Tridiagonal(rand(2), rand(3), rand(2))
@test lu(A1) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@test lu(A1, RowMaximum()) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@test lu(A1, RowMaximum(); check = false) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@testset "lu on *diagonal matrices" begin
dl = rand(3)
d = rand(4)
Bl = Bidiagonal(d, dl, :L)
Bu = Bidiagonal(d, dl, :U)
Tri = Tridiagonal(dl, d, dl)
Sym = SymTridiagonal(d, dl)
D = Diagonal(d)
b = ones(4)
B = rand(4,4)
for A in (Bl, Bu, Tri, Sym, D), pivot in (NoPivot(), RowMaximum())
@test A\b lu(A, pivot)\b
@test B/A B/lu(A, pivot)
@test B/A B/Matrix(A)
@test Matrix(lu(A, pivot)) A
@test @inferred(lu(A)) isa LU
if A isa Union{Bidiagonal, Diagonal, Tridiagonal, SymTridiagonal}
@test lu(A) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@test lu(A, pivot) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
@test lu(A, pivot; check = false) isa LU{Float64, Tridiagonal{Float64, Vector{Float64}}}
end
end
end

end # module TestLU

0 comments on commit f15b4c3

Please sign in to comment.