Skip to content

Commit

Permalink
LinearAlgebra: specialize is{hermitian,symmetric} for {Sym,}Tridiagonal
Browse files Browse the repository at this point in the history
  • Loading branch information
anaveragehuman committed Jun 1, 2021
1 parent acdffeb commit 1044cb3
Show file tree
Hide file tree
Showing 2 changed files with 33 additions and 1 deletion.
10 changes: 9 additions & 1 deletion stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -170,6 +170,10 @@ adjoint(S::SymTridiagonal) = Adjoint(S)
Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...)
Base.copy(S::Transpose{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(transpose.(x)), (S.parent.dv, S.parent.ev))...)

ishermitian(S::SymTridiagonal{<:Real}) = true
ishermitian(S::SymTridiagonal{<:Complex}) = all(x -> imag(x) == 0, S.dv) && all(x -> imag(x) == 0, S.ev)
issymmetric(S::SymTridiagonal) = true

function diag(M::SymTridiagonal{<:Number}, n::Integer=0)
# every branch call similar(..., ::Int) to make sure the
# same vector type is returned independent of n
Expand Down Expand Up @@ -608,6 +612,10 @@ transpose(S::Tridiagonal{<:Number}) = Tridiagonal(S.du, S.d, S.dl)
Base.copy(aS::Adjoint{<:Any,<:Tridiagonal}) = (S = aS.parent; Tridiagonal(map(x -> copy.(adjoint.(x)), (S.du, S.d, S.dl))...))
Base.copy(tS::Transpose{<:Any,<:Tridiagonal}) = (S = tS.parent; Tridiagonal(map(x -> copy.(transpose.(x)), (S.du, S.d, S.dl))...))

ishermitian(S::Tridiagonal{<:Real}) = S.du == S.dl
ishermitian(S::Tridiagonal{<:Complex}) = all(x -> imag(x) == 0, S.d) && S.du == adjoint(S.dl)[1,:]
issymmetric(S::Tridiagonal) = S.du == S.dl

\(A::Adjoint{<:Any,<:Tridiagonal}, B::Adjoint{<:Any,<:StridedVecOrMat}) = copy(A) \ copy(B)

function diag(M::Tridiagonal, n::Integer=0)
Expand Down Expand Up @@ -747,7 +755,7 @@ det(A::Tridiagonal) = det_usmani(A.dl, A.d, A.du)
AbstractMatrix{T}(M::Tridiagonal) where {T} = Tridiagonal{T}(M)
Tridiagonal{T}(M::SymTridiagonal{T}) where {T} = Tridiagonal(M)
function SymTridiagonal{T}(M::Tridiagonal) where T
if M.dl == M.du
if issymmetric(M)
return SymTridiagonal{T}(convert(AbstractVector{T},M.d), convert(AbstractVector{T},M.dl))
else
throw(ArgumentError("Tridiagonal is not symmetric, cannot convert to SymTridiagonal"))
Expand Down
24 changes: 24 additions & 0 deletions stdlib/LinearAlgebra/test/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -617,4 +617,28 @@ end
@test copy(Sc')\b == F'\b
end

@testset "symmetric and hermitian tridiagonals" begin
A = [im 0; 0 -im]
@test issymmetric(A)
@test !ishermitian(A)

# real
A = SymTridiagonal(randn(5), randn(4))
@test issymmetric(A)
@test ishermitian(A)

A = Tridiagonal(A.ev, A.dv, A.ev .+ 1)
@test !issymmetric(A)
@test !ishermitian(A)

# complex
S = SymTridiagonal(randn(5) .+ 1im, randn(4) .+ 1im)
@test issymmetric(S)
@test !ishermitian(S)

S = Tridiagonal(S.ev, zero(S.dv), adjoint(S.ev)[1,:])
@test !issymmetric(S)
@test ishermitian(S)
end

end # module TestTridiagonal

0 comments on commit 1044cb3

Please sign in to comment.