diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index b9737bf36d0c52..cb4cc70256d577 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -265,7 +265,6 @@ end @noinline throw_uplo() = throw(ArgumentError("uplo argument must be either :U (upper) or :L (lower)")) - """ ldiv!(Y, A, B) -> Y @@ -454,6 +453,12 @@ export ⋅, × _cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x _cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X +# SymTridiagonal ev can be the same length as dv, but the last element is +# ignored. However, some methods can fail if they read the entired ev +# rather than just the meaningful elements. This is a helper function +# for getting only the meaningful elements of ev. See #41089 +_evview(S::SymTridiagonal) = @view S.ev[begin:length(S.dv) - 1] + ## append right hand side with zeros if necessary _zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n)) _zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2)) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 5c25c0993e9cc6..5bee27fc782483 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -28,6 +28,9 @@ Tridiagonal(A::Bidiagonal) = # conversions from SymTridiagonal to other special matrix types Diagonal(A::SymTridiagonal) = Diagonal(A.dv) + +# These can fail when ev has the same length as dv +# TODO: Revisit when a good solution for #42477 is found Bidiagonal(A::SymTridiagonal) = iszero(A.ev) ? Bidiagonal(A.dv, A.ev, :U) : throw(ArgumentError("matrix cannot be represented as Bidiagonal")) @@ -154,10 +157,10 @@ end # this set doesn't have the aforementioned problem -+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+B.ev, A.d+B.dv, A.du+B.ev) --(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-B.ev, A.d-B.dv, A.du-B.ev) -+(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev+B.dl, A.dv+B.d, A.ev+B.du) --(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev-B.dl, A.dv-B.d, A.ev-B.du) ++(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+_evview(B), A.d+B.dv, A.du+_evview(B)) +-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-_evview(B), A.d-B.dv, A.du-_evview(B)) ++(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(_evview(A)+B.dl, A.dv+B.d, _evview(A)+B.du) +-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(_evview(A)-B.dl, A.dv-B.d, _evview(A)-B.du) function (+)(A::Diagonal, B::Tridiagonal) @@ -202,22 +205,22 @@ end function (+)(A::Bidiagonal, B::SymTridiagonal) newdv = A.dv+B.dv - Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.ev), A.dv+B.dv, A.ev+B.ev) : (A.ev+B.ev, A.dv+B.dv, typeof(newdv)(B.ev)))...) + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(_evview(B)), A.dv+B.dv, A.ev+_evview(B)) : (A.ev+_evview(B), A.dv+B.dv, typeof(newdv)(_evview(B))))...) end function (-)(A::Bidiagonal, B::SymTridiagonal) newdv = A.dv-B.dv - Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(-B.ev)))...) + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-_evview(B)), newdv, A.ev-_evview(B)) : (A.ev-_evview(B), newdv, typeof(newdv)(-_evview(B))))...) end function (+)(A::SymTridiagonal, B::Bidiagonal) newdv = A.dv+B.dv - Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev+B.ev) : (A.ev+B.ev, newdv, typeof(newdv)(A.ev)))...) + Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(_evview(A)), newdv, _evview(A)+B.ev) : (_evview(A)+B.ev, newdv, typeof(newdv)(_evview(A))))...) end function (-)(A::SymTridiagonal, B::Bidiagonal) newdv = A.dv-B.dv - Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(A.ev)))...) + Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(_evview(A)), newdv, _evview(A)-B.ev) : (_evview(A)-B.ev, newdv, typeof(newdv)(_evview(A))))...) end # fixing uniform scaling problems from #28994 @@ -315,8 +318,8 @@ one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(on # equals and approx equals methods for structured matrices # SymTridiagonal == Tridiagonal is already defined in tridiag.jl -# SymTridiagonal and Bidiagonal have the same field names -==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv +==(A::Diagonal, B::Bidiagonal) = iszero(B.ev) && A.diag == B.dv +==(A::Diagonal, B::SymTridiagonal) = iszero(_evview(B)) && A.diag == B.dv ==(B::Bidiagonal, A::Diagonal) = A == B ==(A::Diagonal, B::Tridiagonal) = iszero(B.dl) && iszero(B.du) && A.diag == B.d @@ -331,5 +334,5 @@ function ==(A::Bidiagonal, B::Tridiagonal) end ==(B::Tridiagonal, A::Bidiagonal) = A == B -==(A::Bidiagonal, B::SymTridiagonal) = iszero(B.ev) && iszero(A.ev) && A.dv == B.dv +==(A::Bidiagonal, B::SymTridiagonal) = iszero(_evview(A)) && iszero(A.ev) && A.dv == B.dv ==(B::SymTridiagonal, A::Bidiagonal) = A == B diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index e781205116cddb..c3528bb1e73a6f 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -160,11 +160,11 @@ similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T # similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...) copyto!(dest::SymTridiagonal, src::SymTridiagonal) = - (copyto!(dest.dv, src.dv); copyto!(dest.ev, src.ev); dest) + (copyto!(dest.dv, src.dv); copyto!(dest.ev, _evview(src)); dest) #Elementary operations for func in (:conj, :copy, :real, :imag) - @eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(M.ev)) + @eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(_evview(M))) end transpose(S::SymTridiagonal) = S @@ -172,7 +172,7 @@ adjoint(S::SymTridiagonal{<:Real}) = S adjoint(S::SymTridiagonal) = Adjoint(S) Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...) -ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(@view S.ev[begin:length(S.dv) - 1]) +ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S)) issymmetric(S::SymTridiagonal) = true function diag(M::SymTridiagonal{<:Number}, n::Integer=0) @@ -182,7 +182,7 @@ function diag(M::SymTridiagonal{<:Number}, n::Integer=0) if absn == 0 return copyto!(similar(M.dv, length(M.dv)), M.dv) elseif absn == 1 - return copyto!(similar(M.ev, length(M.ev)), M.ev) + return copyto!(similar(M.ev, length(M.dv)-1), _evview(M)) elseif absn <= size(M,1) return fill!(similar(M.dv, size(M,1)-absn), 0) else @@ -196,9 +196,9 @@ function diag(M::SymTridiagonal, n::Integer=0) if n == 0 return copyto!(similar(M.dv, length(M.dv)), symmetric.(M.dv, :U)) elseif n == 1 - return copyto!(similar(M.ev, length(M.ev)), M.ev) + return copyto!(similar(M.ev, length(M.dv)-1), _evview(M)) elseif n == -1 - return copyto!(similar(M.ev, length(M.ev)), transpose.(M.ev)) + return copyto!(similar(M.ev, length(M.dv)-1), transpose.(_evview(M))) elseif n <= size(M,1) throw(ArgumentError("requested diagonal contains undefined zeros of an array type")) else @@ -207,14 +207,14 @@ function diag(M::SymTridiagonal, n::Integer=0) end end -+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, A.ev+B.ev) --(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, A.ev-B.ev) ++(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, _evview(A)+_evview(B)) +-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, _evview(A)-_evview(B)) -(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev) *(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B) *(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev) /(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B) \(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev) -==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (A.ev==B.ev) +==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (_evview(A)==_evview(B)) @inline mul!(A::StridedVecOrMat, B::SymTridiagonal, C::StridedVecOrMat, alpha::Number, beta::Number) = @@ -361,25 +361,19 @@ end # tril and triu -# SymTridiagonal ev can be the same length as dv, but the last element is -# ignored. However, some methods can fail if they read the entired ev -# rather than just the meaningful elements. This is a helper function -# for checking if ev is semantically zero. See #41089 -_isevzero(M::SymTridiagonal) = iszero(@view M.ev[begin:length(M.dv) - 1]) - function istriu(M::SymTridiagonal, k::Integer=0) if k <= -1 return true elseif k == 0 - return _isevzero(M) + return iszero(_evview(M)) else # k >= 1 - return _isevzero(M) && iszero(M.dv) + return iszero(_evview(M)) && iszero(M.dv) end end istril(M::SymTridiagonal, k::Integer) = istriu(M, -k) -iszero(M::SymTridiagonal) = _isevzero(M) && iszero(M.dv) -isone(M::SymTridiagonal) = _isevzero(M) && all(isone, M.dv) -isdiag(M::SymTridiagonal) = _isevzero(M) +iszero(M::SymTridiagonal) = iszero(_evview(M)) && iszero(M.dv) +isone(M::SymTridiagonal) = iszero(_evview(M)) && all(isone, M.dv) +isdiag(M::SymTridiagonal) = iszero(_evview(M)) function tril!(M::SymTridiagonal, k::Integer=0) diff --git a/stdlib/LinearAlgebra/test/special.jl b/stdlib/LinearAlgebra/test/special.jl index bf4c8dee589775..e0c5f87111b07a 100644 --- a/stdlib/LinearAlgebra/test/special.jl +++ b/stdlib/LinearAlgebra/test/special.jl @@ -450,4 +450,27 @@ end @test A*Sym ≈ A*Matrix(Sym) end +@testset "Ops on SymTridiagonal ev has the same length as dv" begin + x = rand(3) + y = rand(3) + z = rand(2) + + S = SymTridiagonal(x, y) + T = Tridiagonal(z, x, z) + Bu = Bidiagonal(x, z, :U) + Bl = Bidiagonal(x, z, :L) + + Ms = Matrix(S) + Mt = Matrix(T) + Mbu = Matrix(Bu) + Mbl = Matrix(Bl) + + @test S + T ≈ Ms + Mt + @test T + S ≈ Mt + Ms + @test S + Bu ≈ Ms + Mbu + @test Bu + S ≈ Mbu + Ms + @test S + Bl ≈ Ms + Mbl + @test Bl + S ≈ Mbl + Ms +end + end # module TestSpecial