From d199de0dba6e9d6fd6a5866087cb76a6e4001e2e Mon Sep 17 00:00:00 2001 From: Marco Cognetta Date: Sat, 2 Oct 2021 23:21:28 -0500 Subject: [PATCH 1/2] squash ev fixes --- stdlib/LinearAlgebra/src/LinearAlgebra.jl | 7 +++++- stdlib/LinearAlgebra/src/special.jl | 25 ++++++++++--------- stdlib/LinearAlgebra/src/tridiag.jl | 29 ++++++++++++----------- stdlib/LinearAlgebra/test/special.jl | 23 ++++++++++++++++++ stdlib/LinearAlgebra/test/tridiag.jl | 19 +++++++++++++++ 5 files changed, 77 insertions(+), 26 deletions(-) diff --git a/stdlib/LinearAlgebra/src/LinearAlgebra.jl b/stdlib/LinearAlgebra/src/LinearAlgebra.jl index b9737bf36d0c5..cb4cc70256d57 100644 --- a/stdlib/LinearAlgebra/src/LinearAlgebra.jl +++ b/stdlib/LinearAlgebra/src/LinearAlgebra.jl @@ -265,7 +265,6 @@ end @noinline throw_uplo() = throw(ArgumentError("uplo argument must be either :U (upper) or :L (lower)")) - """ ldiv!(Y, A, B) -> Y @@ -454,6 +453,12 @@ export ⋅, × _cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x _cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X +# SymTridiagonal ev can be the same length as dv, but the last element is +# ignored. However, some methods can fail if they read the entired ev +# rather than just the meaningful elements. This is a helper function +# for getting only the meaningful elements of ev. See #41089 +_evview(S::SymTridiagonal) = @view S.ev[begin:length(S.dv) - 1] + ## append right hand side with zeros if necessary _zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n)) _zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2)) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index 5c25c0993e9cc..9b911e40ed75f 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -28,6 +28,9 @@ Tridiagonal(A::Bidiagonal) = # conversions from SymTridiagonal to other special matrix types Diagonal(A::SymTridiagonal) = Diagonal(A.dv) + +# These can fail when ev has the same length as dv +# TODO: Revisit when a good solution for #42477 is found Bidiagonal(A::SymTridiagonal) = iszero(A.ev) ? Bidiagonal(A.dv, A.ev, :U) : throw(ArgumentError("matrix cannot be represented as Bidiagonal")) @@ -154,10 +157,10 @@ end # this set doesn't have the aforementioned problem -+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+B.ev, A.d+B.dv, A.du+B.ev) --(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-B.ev, A.d-B.dv, A.du-B.ev) -+(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev+B.dl, A.dv+B.d, A.ev+B.du) --(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev-B.dl, A.dv-B.d, A.ev-B.du) ++(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+_evview(B), A.d+B.dv, A.du+_evview(B)) +-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-_evview(B), A.d-B.dv, A.du-_evview(B)) ++(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(_evview(A)+B.dl, A.dv+B.d, _evview(A)+B.du) +-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(_evview(A)-B.dl, A.dv-B.d, _evview(A)-B.du) function (+)(A::Diagonal, B::Tridiagonal) @@ -202,22 +205,22 @@ end function (+)(A::Bidiagonal, B::SymTridiagonal) newdv = A.dv+B.dv - Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.ev), A.dv+B.dv, A.ev+B.ev) : (A.ev+B.ev, A.dv+B.dv, typeof(newdv)(B.ev)))...) + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(_evview(B)), A.dv+B.dv, A.ev+_evview(B)) : (A.ev+_evview(B), A.dv+B.dv, typeof(newdv)(_evview(B))))...) end function (-)(A::Bidiagonal, B::SymTridiagonal) newdv = A.dv-B.dv - Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(-B.ev)))...) + Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-_evview(B)), newdv, A.ev-_evview(B)) : (A.ev-_evview(B), newdv, typeof(newdv)(-_evview(B))))...) end function (+)(A::SymTridiagonal, B::Bidiagonal) newdv = A.dv+B.dv - Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev+B.ev) : (A.ev+B.ev, newdv, typeof(newdv)(A.ev)))...) + Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(_evview(A)), newdv, _evview(A)+B.ev) : (_evview(A)+B.ev, newdv, typeof(newdv)(_evview(A))))...) end function (-)(A::SymTridiagonal, B::Bidiagonal) newdv = A.dv-B.dv - Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(A.ev)))...) + Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(_evview(A)), newdv, _evview(A)-B.ev) : (_evview(A)-B.ev, newdv, typeof(newdv)(_evview(A))))...) end # fixing uniform scaling problems from #28994 @@ -315,8 +318,8 @@ one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(on # equals and approx equals methods for structured matrices # SymTridiagonal == Tridiagonal is already defined in tridiag.jl -# SymTridiagonal and Bidiagonal have the same field names -==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv +==(A::Diagonal, B::Bidiagonal) = iszero(B.ev) && A.diag == B.dv +==(A::Diagonal, B::SymTridiagonal) = iszero(_evview(B)) && A.diag == B.dv ==(B::Bidiagonal, A::Diagonal) = A == B ==(A::Diagonal, B::Tridiagonal) = iszero(B.dl) && iszero(B.du) && A.diag == B.d @@ -331,5 +334,5 @@ function ==(A::Bidiagonal, B::Tridiagonal) end ==(B::Tridiagonal, A::Bidiagonal) = A == B -==(A::Bidiagonal, B::SymTridiagonal) = iszero(B.ev) && iszero(A.ev) && A.dv == B.dv +==(A::Bidiagonal, B::SymTridiagonal) = iszero(_evview(B)) && iszero(A.ev) && A.dv == B.dv ==(B::SymTridiagonal, A::Bidiagonal) = A == B diff --git a/stdlib/LinearAlgebra/src/tridiag.jl b/stdlib/LinearAlgebra/src/tridiag.jl index cc551e4911acf..e3df66c2de590 100644 --- a/stdlib/LinearAlgebra/src/tridiag.jl +++ b/stdlib/LinearAlgebra/src/tridiag.jl @@ -160,7 +160,7 @@ similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T # similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...) copyto!(dest::SymTridiagonal, src::SymTridiagonal) = - (copyto!(dest.dv, src.dv); copyto!(dest.ev, src.ev); dest) + (copyto!(dest.dv, src.dv); copyto!(dest.ev, _evview(src)); dest) #Elementary operations for func in (:conj, :copy, :real, :imag) @@ -172,7 +172,7 @@ adjoint(S::SymTridiagonal{<:Real}) = S adjoint(S::SymTridiagonal) = Adjoint(S) Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...) -ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(@view S.ev[begin:length(S.dv) - 1]) +ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S)) issymmetric(S::SymTridiagonal) = true function diag(M::SymTridiagonal{<:Number}, n::Integer=0) @@ -182,7 +182,7 @@ function diag(M::SymTridiagonal{<:Number}, n::Integer=0) if absn == 0 return copyto!(similar(M.dv, length(M.dv)), M.dv) elseif absn == 1 - return copyto!(similar(M.ev, length(M.ev)), M.ev) + return copyto!(similar(M.ev, length(M.dv)-1), _evview(M)) elseif absn <= size(M,1) return fill!(similar(M.dv, size(M,1)-absn), 0) else @@ -196,9 +196,9 @@ function diag(M::SymTridiagonal, n::Integer=0) if n == 0 return copyto!(similar(M.dv, length(M.dv)), symmetric.(M.dv, :U)) elseif n == 1 - return copyto!(similar(M.ev, length(M.ev)), M.ev) + return copyto!(similar(M.ev, length(M.dv)-1), _evview(M)) elseif n == -1 - return copyto!(similar(M.ev, length(M.ev)), transpose.(M.ev)) + return copyto!(similar(M.ev, length(M.dv)-1), transpose.(_evview(M))) elseif n <= size(M,1) throw(ArgumentError("requested diagonal contains undefined zeros of an array type")) else @@ -207,14 +207,14 @@ function diag(M::SymTridiagonal, n::Integer=0) end end -+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, A.ev+B.ev) --(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, A.ev-B.ev) ++(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, _evview(A)+_evview(B)) +-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, _evview(A)-_evview(B)) -(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev) *(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B) *(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev) /(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B) \(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev) -==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (A.ev==B.ev) +==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (_evview(A)==_evview(B)) @inline mul!(A::StridedVecOrMat, B::SymTridiagonal, C::StridedVecOrMat, alpha::Number, beta::Number) = @@ -359,21 +359,22 @@ function svdvals!(A::SymTridiagonal) return sort!(map!(abs, vals, vals); rev=true) end -#tril and triu +# tril and triu function istriu(M::SymTridiagonal, k::Integer=0) if k <= -1 return true elseif k == 0 - return iszero(M.ev) + return iszero(_evview(M)) else # k >= 1 - return iszero(M.ev) && iszero(M.dv) + return iszero(_evview(M)) && iszero(M.dv) end end istril(M::SymTridiagonal, k::Integer) = istriu(M, -k) -iszero(M::SymTridiagonal) = iszero(M.ev) && iszero(M.dv) -isone(M::SymTridiagonal) = iszero(M.ev) && all(isone, M.dv) -isdiag(M::SymTridiagonal) = iszero(M.ev) +iszero(M::SymTridiagonal) = iszero(_evview(M)) && iszero(M.dv) +isone(M::SymTridiagonal) = iszero(_evview(M)) && all(isone, M.dv) +isdiag(M::SymTridiagonal) = iszero(_evview(M)) + function tril!(M::SymTridiagonal, k::Integer=0) n = length(M.dv) diff --git a/stdlib/LinearAlgebra/test/special.jl b/stdlib/LinearAlgebra/test/special.jl index bf4c8dee58977..e0c5f87111b07 100644 --- a/stdlib/LinearAlgebra/test/special.jl +++ b/stdlib/LinearAlgebra/test/special.jl @@ -450,4 +450,27 @@ end @test A*Sym ≈ A*Matrix(Sym) end +@testset "Ops on SymTridiagonal ev has the same length as dv" begin + x = rand(3) + y = rand(3) + z = rand(2) + + S = SymTridiagonal(x, y) + T = Tridiagonal(z, x, z) + Bu = Bidiagonal(x, z, :U) + Bl = Bidiagonal(x, z, :L) + + Ms = Matrix(S) + Mt = Matrix(T) + Mbu = Matrix(Bu) + Mbl = Matrix(Bl) + + @test S + T ≈ Ms + Mt + @test T + S ≈ Mt + Ms + @test S + Bu ≈ Ms + Mbu + @test Bu + S ≈ Mbu + Ms + @test S + Bl ≈ Ms + Mbl + @test Bl + S ≈ Mbl + Ms +end + end # module TestSpecial diff --git a/stdlib/LinearAlgebra/test/tridiag.jl b/stdlib/LinearAlgebra/test/tridiag.jl index 31e107ddc0e3c..39c2b64bab070 100644 --- a/stdlib/LinearAlgebra/test/tridiag.jl +++ b/stdlib/LinearAlgebra/test/tridiag.jl @@ -164,6 +164,19 @@ end @test !isdiag(Tridiagonal(dl,d,zerosdu)) @test !isdiag(Tridiagonal(zerosdl,d,du)) @test !isdiag(Tridiagonal(dl,d,du)) + + # Test methods that could fail due to dv and ev having the same length + # see #41089 + + badev = zero(d) + badev[end] = 1 + S = SymTridiagonal(d, badev) + + @test istriu(S, -2) + @test istriu(S, 0) + @test !istriu(S, 2) + + @test isdiag(S) end @testset "iszero and isone" begin @@ -190,6 +203,12 @@ end @test isone(Sone) @test !iszero(Smix) @test !isone(Smix) + + badev = zeros(elty, 3) + badev[end] = 1 + + @test isone(SymTridiagonal(ones(elty, 3), badev)) + @test iszero(SymTridiagonal(zeros(elty, 3), badev)) end @testset for mat_type in (Tridiagonal, SymTridiagonal) From 15a2580e4b20c50bf64f84b19406693617a16c13 Mon Sep 17 00:00:00 2001 From: Marco Date: Thu, 7 Oct 2021 21:55:57 -0700 Subject: [PATCH 2/2] remove whitespace --- stdlib/LinearAlgebra/src/special.jl | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/stdlib/LinearAlgebra/src/special.jl b/stdlib/LinearAlgebra/src/special.jl index a9d2b8dd947fa..b71e588b87feb 100644 --- a/stdlib/LinearAlgebra/src/special.jl +++ b/stdlib/LinearAlgebra/src/special.jl @@ -318,12 +318,12 @@ one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(on zero(D::Diagonal) = Diagonal(zero.(D.diag)) oneunit(D::Diagonal) = Diagonal(oneunit.(D.diag)) - + # equals and approx equals methods for structured matrices # SymTridiagonal == Tridiagonal is already defined in tridiag.jl ==(A::Diagonal, B::Bidiagonal) = iszero(B.ev) && A.diag == B.dv -==(A::Diagonal, B::SymTridiagonal) = iszero(_evview(B)) && A.diag == B.dv +==(A::Diagonal, B::SymTridiagonal) = iszero(_evview(B)) && A.diag == B.dv ==(B::Bidiagonal, A::Diagonal) = A == B ==(A::Diagonal, B::Tridiagonal) = iszero(B.dl) && iszero(B.du) && A.diag == B.d ==(B::Tridiagonal, A::Diagonal) = A == B