diff --git a/base/linalg/diagonal.jl b/base/linalg/diagonal.jl index 16cf79b6e4e74e..c65496959c5d1d 100644 --- a/base/linalg/diagonal.jl +++ b/base/linalg/diagonal.jl @@ -55,8 +55,29 @@ isposdef(D::Diagonal) = all(D.diag .> 0) factorize(D::Diagonal) = D -tril(D::Diagonal,i::Integer=0) = i == 0 ? D : zeros(D) -triu(D::Diagonal,i::Integer=0) = i == 0 ? D : zeros(D) +istriu(D::Diagonal) = true +istril(D::Diagonal) = true +function triu!(D::Diagonal,k::Integer=0) + n = size(D,1) + if abs(k) > n + throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)")) + elseif k == 0 + return D + else + return zeros(D) + end +end + +function tril!(D::Diagonal,k::Integer=0) + n = size(D,1) + if abs(k) > n + throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)")) + elseif k == 0 + return D + else + return zeros(D) + end +end ==(Da::Diagonal, Db::Diagonal) = Da.diag == Db.diag -(A::Diagonal)=Diagonal(-A.diag) diff --git a/test/linalg/diagonal.jl b/test/linalg/diagonal.jl index 0d84b0e2cca3e1..6518cd1e202be6 100644 --- a/test/linalg/diagonal.jl +++ b/test/linalg/diagonal.jl @@ -84,10 +84,12 @@ for relty in (Float32, Float64, BigFloat), elty in (relty, Complex{relty}) @test_approx_eq D/D2 Diagonal(D.diag./D2.diag) # test triu/tril - @test triu(D,1) == zeros(D) - @test triu(D,0) == D - @test tril(D,1) == zeros(D) - @test tril(D,0) == D + @test istriu(D) + @test istril(D) + @test triu!(D,1) == zeros(D) + @test triu!(D,0) == D + @test tril!(D,1) == zeros(D) + @test tril!(D,0) == D # factorize @test factorize(D) == D