Skip to content

Commit

Permalink
handle edge cases for +/- in special.jl
Browse files Browse the repository at this point in the history
  • Loading branch information
mcognetta committed Oct 3, 2021
1 parent 50a59c0 commit 6da5108
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 32 deletions.
7 changes: 6 additions & 1 deletion stdlib/LinearAlgebra/src/LinearAlgebra.jl
Original file line number Diff line number Diff line change
Expand Up @@ -265,7 +265,6 @@ end

@noinline throw_uplo() = throw(ArgumentError("uplo argument must be either :U (upper) or :L (lower)"))


"""
ldiv!(Y, A, B) -> Y
Expand Down Expand Up @@ -454,6 +453,12 @@ export ⋅, ×
_cut_B(x::AbstractVector, r::UnitRange) = length(x) > length(r) ? x[r] : x
_cut_B(X::AbstractMatrix, r::UnitRange) = size(X, 1) > length(r) ? X[r,:] : X

# SymTridiagonal ev can be the same length as dv, but the last element is
# ignored. However, some methods can fail if they read the entired ev
# rather than just the meaningful elements. This is a helper function
# for getting only the meaningful elements of ev. See #41089
_evview(S::SymTridiagonal) = @view S.ev[begin:length(S.dv) - 1]

## append right hand side with zeros if necessary
_zeros(::Type{T}, b::AbstractVector, n::Integer) where {T} = zeros(T, max(length(b), n))
_zeros(::Type{T}, B::AbstractMatrix, n::Integer) where {T} = zeros(T, max(size(B, 1), n), size(B, 2))
Expand Down
25 changes: 14 additions & 11 deletions stdlib/LinearAlgebra/src/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,9 @@ Tridiagonal(A::Bidiagonal) =

# conversions from SymTridiagonal to other special matrix types
Diagonal(A::SymTridiagonal) = Diagonal(A.dv)

# These can fail when ev has the same length as dv
# TODO: Revisit when a good solution for #42477 is found
Bidiagonal(A::SymTridiagonal) =
iszero(A.ev) ? Bidiagonal(A.dv, A.ev, :U) :
throw(ArgumentError("matrix cannot be represented as Bidiagonal"))
Expand Down Expand Up @@ -154,10 +157,10 @@ end

# this set doesn't have the aforementioned problem

+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+B.ev, A.d+B.dv, A.du+B.ev)
-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-B.ev, A.d-B.dv, A.du-B.ev)
+(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev+B.dl, A.dv+B.d, A.ev+B.du)
-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(A.ev-B.dl, A.dv-B.d, A.ev-B.du)
+(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl+_evview(B), A.d+B.dv, A.du+_evview(B))
-(A::Tridiagonal, B::SymTridiagonal) = Tridiagonal(A.dl-_evview(B), A.d-B.dv, A.du-_evview(B))
+(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(_evview(A)+B.dl, A.dv+B.d, _evview(A)+B.du)
-(A::SymTridiagonal, B::Tridiagonal) = Tridiagonal(_evview(A)-B.dl, A.dv-B.d, _evview(A)-B.du)


function (+)(A::Diagonal, B::Tridiagonal)
Expand Down Expand Up @@ -202,22 +205,22 @@ end

function (+)(A::Bidiagonal, B::SymTridiagonal)
newdv = A.dv+B.dv
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(B.ev), A.dv+B.dv, A.ev+B.ev) : (A.ev+B.ev, A.dv+B.dv, typeof(newdv)(B.ev)))...)
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(_evview(B)), A.dv+B.dv, A.ev+_evview(B)) : (A.ev+_evview(B), A.dv+B.dv, typeof(newdv)(_evview(B))))...)
end

function (-)(A::Bidiagonal, B::SymTridiagonal)
newdv = A.dv-B.dv
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-B.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(-B.ev)))...)
Tridiagonal((A.uplo == 'U' ? (typeof(newdv)(-_evview(B)), newdv, A.ev-_evview(B)) : (A.ev-_evview(B), newdv, typeof(newdv)(-_evview(B))))...)
end

function (+)(A::SymTridiagonal, B::Bidiagonal)
newdv = A.dv+B.dv
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev+B.ev) : (A.ev+B.ev, newdv, typeof(newdv)(A.ev)))...)
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(_evview(A)), newdv, _evview(A)+B.ev) : (_evview(A)+B.ev, newdv, typeof(newdv)(_evview(A))))...)
end

function (-)(A::SymTridiagonal, B::Bidiagonal)
newdv = A.dv-B.dv
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(A.ev), newdv, A.ev-B.ev) : (A.ev-B.ev, newdv, typeof(newdv)(A.ev)))...)
Tridiagonal((B.uplo == 'U' ? (typeof(newdv)(_evview(A)), newdv, _evview(A)-B.ev) : (_evview(A)-B.ev, newdv, typeof(newdv)(_evview(A))))...)
end

# fixing uniform scaling problems from #28994
Expand Down Expand Up @@ -315,8 +318,8 @@ one(A::SymTridiagonal{T}) where T = SymTridiagonal(fill!(similar(A.dv, typeof(on
# equals and approx equals methods for structured matrices
# SymTridiagonal == Tridiagonal is already defined in tridiag.jl

# SymTridiagonal and Bidiagonal have the same field names
==(A::Diagonal, B::Union{SymTridiagonal, Bidiagonal}) = iszero(B.ev) && A.diag == B.dv
==(A::Diagonal, B::Bidiagonal) = iszero(B.ev) && A.diag == B.dv
==(A::Diagonal, B::SymTridiagonal) = iszero(_evview(B)) && A.diag == B.dv
==(B::Bidiagonal, A::Diagonal) = A == B

==(A::Diagonal, B::Tridiagonal) = iszero(B.dl) && iszero(B.du) && A.diag == B.d
Expand All @@ -331,5 +334,5 @@ function ==(A::Bidiagonal, B::Tridiagonal)
end
==(B::Tridiagonal, A::Bidiagonal) = A == B

==(A::Bidiagonal, B::SymTridiagonal) = iszero(B.ev) && iszero(A.ev) && A.dv == B.dv
==(A::Bidiagonal, B::SymTridiagonal) = iszero(_evview(A)) && iszero(A.ev) && A.dv == B.dv
==(B::SymTridiagonal, A::Bidiagonal) = A == B
34 changes: 14 additions & 20 deletions stdlib/LinearAlgebra/src/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -160,19 +160,19 @@ similar(S::SymTridiagonal, ::Type{T}) where {T} = SymTridiagonal(similar(S.dv, T
# similar(S::SymTridiagonal, ::Type{T}, dims::Union{Dims{1},Dims{2}}) where {T} = spzeros(T, dims...)

copyto!(dest::SymTridiagonal, src::SymTridiagonal) =
(copyto!(dest.dv, src.dv); copyto!(dest.ev, src.ev); dest)
(copyto!(dest.dv, src.dv); copyto!(dest.ev, _evview(src)); dest)

#Elementary operations
for func in (:conj, :copy, :real, :imag)
@eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(M.ev))
@eval ($func)(M::SymTridiagonal) = SymTridiagonal(($func)(M.dv), ($func)(_evview(M)))
end

transpose(S::SymTridiagonal) = S
adjoint(S::SymTridiagonal{<:Real}) = S
adjoint(S::SymTridiagonal) = Adjoint(S)
Base.copy(S::Adjoint{<:Any,<:SymTridiagonal}) = SymTridiagonal(map(x -> copy.(adjoint.(x)), (S.parent.dv, S.parent.ev))...)

ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(@view S.ev[begin:length(S.dv) - 1])
ishermitian(S::SymTridiagonal) = isreal(S.dv) && isreal(_evview(S))
issymmetric(S::SymTridiagonal) = true

function diag(M::SymTridiagonal{<:Number}, n::Integer=0)
Expand All @@ -182,7 +182,7 @@ function diag(M::SymTridiagonal{<:Number}, n::Integer=0)
if absn == 0
return copyto!(similar(M.dv, length(M.dv)), M.dv)
elseif absn == 1
return copyto!(similar(M.ev, length(M.ev)), M.ev)
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
elseif absn <= size(M,1)
return fill!(similar(M.dv, size(M,1)-absn), 0)
else
Expand All @@ -196,9 +196,9 @@ function diag(M::SymTridiagonal, n::Integer=0)
if n == 0
return copyto!(similar(M.dv, length(M.dv)), symmetric.(M.dv, :U))
elseif n == 1
return copyto!(similar(M.ev, length(M.ev)), M.ev)
return copyto!(similar(M.ev, length(M.dv)-1), _evview(M))
elseif n == -1
return copyto!(similar(M.ev, length(M.ev)), transpose.(M.ev))
return copyto!(similar(M.ev, length(M.dv)-1), transpose.(_evview(M)))
elseif n <= size(M,1)
throw(ArgumentError("requested diagonal contains undefined zeros of an array type"))
else
Expand All @@ -207,14 +207,14 @@ function diag(M::SymTridiagonal, n::Integer=0)
end
end

+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, A.ev+B.ev)
-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, A.ev-B.ev)
+(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv+B.dv, _evview(A)+_evview(B))
-(A::SymTridiagonal, B::SymTridiagonal) = SymTridiagonal(A.dv-B.dv, _evview(A)-_evview(B))
-(A::SymTridiagonal) = SymTridiagonal(-A.dv, -A.ev)
*(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv*B, A.ev*B)
*(B::Number, A::SymTridiagonal) = SymTridiagonal(B*A.dv, B*A.ev)
/(A::SymTridiagonal, B::Number) = SymTridiagonal(A.dv/B, A.ev/B)
\(B::Number, A::SymTridiagonal) = SymTridiagonal(B\A.dv, B\A.ev)
==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (A.ev==B.ev)
==(A::SymTridiagonal, B::SymTridiagonal) = (A.dv==B.dv) && (_evview(A)==_evview(B))

@inline mul!(A::StridedVecOrMat, B::SymTridiagonal, C::StridedVecOrMat,
alpha::Number, beta::Number) =
Expand Down Expand Up @@ -361,25 +361,19 @@ end

# tril and triu

# SymTridiagonal ev can be the same length as dv, but the last element is
# ignored. However, some methods can fail if they read the entired ev
# rather than just the meaningful elements. This is a helper function
# for checking if ev is semantically zero. See #41089
_isevzero(M::SymTridiagonal) = iszero(@view M.ev[begin:length(M.dv) - 1])

function istriu(M::SymTridiagonal, k::Integer=0)
if k <= -1
return true
elseif k == 0
return _isevzero(M)
return iszero(_evview(M))
else # k >= 1
return _isevzero(M) && iszero(M.dv)
return iszero(_evview(M)) && iszero(M.dv)
end
end
istril(M::SymTridiagonal, k::Integer) = istriu(M, -k)
iszero(M::SymTridiagonal) = _isevzero(M) && iszero(M.dv)
isone(M::SymTridiagonal) = _isevzero(M) && all(isone, M.dv)
isdiag(M::SymTridiagonal) = _isevzero(M)
iszero(M::SymTridiagonal) = iszero(_evview(M)) && iszero(M.dv)
isone(M::SymTridiagonal) = iszero(_evview(M)) && all(isone, M.dv)
isdiag(M::SymTridiagonal) = iszero(_evview(M))


function tril!(M::SymTridiagonal, k::Integer=0)
Expand Down
23 changes: 23 additions & 0 deletions stdlib/LinearAlgebra/test/special.jl
Original file line number Diff line number Diff line change
Expand Up @@ -450,4 +450,27 @@ end
@test A*Sym A*Matrix(Sym)
end

@testset "Ops on SymTridiagonal ev has the same length as dv" begin
x = rand(3)
y = rand(3)
z = rand(2)

S = SymTridiagonal(x, y)
T = Tridiagonal(z, x, z)
Bu = Bidiagonal(x, z, :U)
Bl = Bidiagonal(x, z, :L)

Ms = Matrix(S)
Mt = Matrix(T)
Mbu = Matrix(Bu)
Mbl = Matrix(Bl)

@test S + T Ms + Mt
@test T + S Mt + Ms
@test S + Bu Ms + Mbu
@test Bu + S Mbu + Ms
@test S + Bl Ms + Mbl
@test Bl + S Mbl + Ms
end

end # module TestSpecial

0 comments on commit 6da5108

Please sign in to comment.