From a67c4e275dc03bcb809877a3240f9128c4748421 Mon Sep 17 00:00:00 2001 From: kshyatt Date: Thu, 13 Aug 2015 15:02:00 -0700 Subject: [PATCH] Fix residual issues and add fallback methods --- base/linalg/diagonal.jl | 6 ++---- base/linalg/generic.jl | 2 ++ base/linalg/triangular.jl | 19 ++++++++----------- base/linalg/tridiag.jl | 8 ++------ test/linalg/diagonal.jl | 10 ++++++---- 5 files changed, 20 insertions(+), 25 deletions(-) diff --git a/base/linalg/diagonal.jl b/base/linalg/diagonal.jl index a9ac1c4de9e283..b4dd70c57f037f 100644 --- a/base/linalg/diagonal.jl +++ b/base/linalg/diagonal.jl @@ -61,23 +61,21 @@ function triu!(D::Diagonal,k::Integer=0) n = size(D,1) if abs(k) > n throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)")) - elseif k != 0 + elseif k > 0 fill!(D.diag,0) end return D end -triu(D::Diagonal,k::Integer=0) = triu!(copy(D),k) function tril!(D::Diagonal,k::Integer=0) n = size(D,1) if abs(k) > n throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)")) - elseif k != 0 + elseif k < 0 fill!(D.diag,0) end return D end -tril(D::Diagonal,k::Integer=0) = tril!(copy(D),k) ==(Da::Diagonal, Db::Diagonal) = Da.diag == Db.diag -(A::Diagonal)=Diagonal(-A.diag) diff --git a/base/linalg/generic.jl b/base/linalg/generic.jl index d7b76cb77b363b..33ff90cd07391b 100644 --- a/base/linalg/generic.jl +++ b/base/linalg/generic.jl @@ -40,6 +40,8 @@ cross(a::AbstractVector, b::AbstractVector) = [a[2]*b[3]-a[3]*b[2], a[3]*b[1]-a[ triu(M::AbstractMatrix) = triu!(copy(M)) tril(M::AbstractMatrix) = tril!(copy(M)) +triu(M::AbstractMatrix,k::Integer) = triu!(copy(M),k) +tril(M::AbstractMatrix,k::Integer) = tril!(copy(M),k) triu!(M::AbstractMatrix) = triu!(M,0) tril!(M::AbstractMatrix) = tril!(M,0) diff --git a/base/linalg/triangular.jl b/base/linalg/triangular.jl index a185c0ea3ac56d..5eda17721192d1 100644 --- a/base/linalg/triangular.jl +++ b/base/linalg/triangular.jl @@ -157,7 +157,10 @@ function tril!(A::UpperTriangular,k::Integer=0) fill!(A.data,0) return A elseif k == 0 - return UpperTriangular(diagm(diag(A))) + for j in 1:n, i in 1:j-1 + A.data[i,j] = 0 + end + return A else return UpperTriangular(tril!(A.data,k)) end @@ -200,7 +203,10 @@ function triu!(A::LowerTriangular,k::Integer=0) fill!(A.data,0) return A elseif k == 0 - return LowerTriangular(diagm(diag(A))) + for j in 1:n, i in j+1:n + A.data[i,j] = 0 + end + return A else return LowerTriangular(triu!(A.data,k)) end @@ -236,15 +242,6 @@ function tril!(A::UnitLowerTriangular,k::Integer=0) return tril!(LowerTriangular(A.data),k) end -tril(A::UpperTriangular,k::Integer=0) = tril!(copy(A),k) -tril(A::LowerTriangular,k::Integer=0) = tril!(copy(A),k) -tril(A::UnitUpperTriangular,k::Integer=0) = tril!(copy(A),k) -tril(A::UnitLowerTriangular,k::Integer=0) = tril!(copy(A),k) -triu(A::UpperTriangular,k::Integer=0) = triu!(copy(A),k) -triu(A::LowerTriangular,k::Integer=0) = triu!(copy(A),k) -triu(A::UnitUpperTriangular,k::Integer=0) = triu!(copy(A),k) -triu(A::UnitLowerTriangular,k::Integer=0) = triu!(copy(A),k) - transpose{T,S}(A::LowerTriangular{T,S}) = UpperTriangular{T, S}(transpose(A.data)) transpose{T,S}(A::UnitLowerTriangular{T,S}) = UnitUpperTriangular{T, S}(transpose(A.data)) transpose{T,S}(A::UpperTriangular{T,S}) = LowerTriangular{T, S}(transpose(A.data)) diff --git a/base/linalg/tridiag.jl b/base/linalg/tridiag.jl index ddc82e9ee9b2b5..6aef654acdc37e 100644 --- a/base/linalg/tridiag.jl +++ b/base/linalg/tridiag.jl @@ -156,7 +156,7 @@ function tril!(M::SymTridiagonal, k::Integer=0) elseif k < -1 fill!(M.ev,0) fill!(M.dv,0) - return Tridiagonal(M.ev,M.dv,M.ev) + return Tridiagonal(M.ev,M.dv,copy(M.ev)) elseif k == -1 fill!(M.dv,0) return Tridiagonal(M.ev,M.dv,zeros(M.ev)) @@ -166,7 +166,6 @@ function tril!(M::SymTridiagonal, k::Integer=0) return Tridiagonal(M.ev,M.dv,copy(M.ev)) end end -tril(M::SymTridiagonal, k::Integer=0) = tril!(copy(M),k) function triu!(M::SymTridiagonal, k::Integer=0) n = length(M.dv) @@ -175,7 +174,7 @@ function triu!(M::SymTridiagonal, k::Integer=0) elseif k > 1 fill!(M.ev,0) fill!(M.dv,0) - return Tridiagonal(M.ev,M.dv,M.ev) + return Tridiagonal(M.ev,M.dv,copy(M.ev)) elseif k == 1 fill!(M.dv,0) return Tridiagonal(zeros(M.ev),M.dv,M.ev) @@ -185,7 +184,6 @@ function triu!(M::SymTridiagonal, k::Integer=0) return Tridiagonal(M.ev,M.dv,copy(M.ev)) end end -triu(M::SymTridiagonal, k::Integer=0) = triu!(copy(M),k) ################### # Generic methods # @@ -383,7 +381,6 @@ function tril!(M::Tridiagonal, k::Integer=0) end return M end -tril(M::Tridiagonal, k::Integer=0) = tril!(copy(M),k) function triu!(M::Tridiagonal, k::Integer=0) n = length(M.d) @@ -401,7 +398,6 @@ function triu!(M::Tridiagonal, k::Integer=0) end return M end -triu(M::Tridiagonal, k::Integer=0) = triu!(copy(M),k) ################### # Generic methods # diff --git a/test/linalg/diagonal.jl b/test/linalg/diagonal.jl index 7e9e4db1b79412..a7ca0d619771a9 100644 --- a/test/linalg/diagonal.jl +++ b/test/linalg/diagonal.jl @@ -86,10 +86,12 @@ for relty in (Float32, Float64, BigFloat), elty in (relty, Complex{relty}) # test triu/tril @test istriu(D) @test istril(D) - @test triu(D,1) == zeros(D) - @test triu(D,0) == D - @test tril(D,1) == zeros(D) - @test tril(D,0) == D + @test triu(D,1) == zeros(D) + @test triu(D,0) == D + @test triu(D,-1) == D + @test tril(D,1) == D + @test tril(D,-1) == zeros(D) + @test tril(D,0) == D # factorize @test factorize(D) == D