Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Address type stability issues in #12574 and fix a bug or two #12594

Merged
merged 6 commits into from
Aug 14, 2015
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 16 additions & 18 deletions base/linalg/bidiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -108,42 +108,40 @@ ctranspose(M::Bidiagonal) = Bidiagonal(conj(M.dv), conj(M.ev), !M.isupper)
istriu(M::Bidiagonal) = M.isupper || all(M.ev .== 0)
istril(M::Bidiagonal) = !M.isupper || all(M.ev .== 0)

function tril(M::Bidiagonal, k::Integer=0)
function tril!(M::Bidiagonal, k::Integer=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this could use fill! instead of allocating new zeros arrays, right?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I tested this locally and it works. Will wait for CI to finish then update.

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think this applies throughout for (most of?) the rest of the types too. Anywhere it's possible to make the ! versions allocation-free and in-place, then I think that would be a good goal. And aim for type-stability but with the smallest amount of type widening that makes sense. The one-arg case could possibly be made to return a more specialized type than the two-arg case for some types since it's more specific in its behavior, but that would require extra methods rather than using default argument values and could be left as a future enhancement.

n = length(M.dv)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif M.isupper && k < 0
return Bidiagonal(zeros(M.dv),zeros(M.ev),M.isupper)
fill!(M.dv,0)
fill!(M.ev,0)
elseif k < -1
return Bidiagonal(zeros(M.dv),zeros(M.ev),M.isupper)
elseif !M.isupper && k == 0
return M
fill!(M.dv,0)
fill!(M.ev,0)
elseif M.isupper && k == 0
return Bidiagonal(M.dv,zeros(M.ev),M.isupper)
fill!(M.ev,0)
elseif !M.isupper && k == -1
return Bidiagonal(zeros(M.dv),M.ev,M.isupper)
elseif k > 0
return M
fill!(M.dv,0)
end
return M
end

function triu(M::Bidiagonal, k::Integer=0)
function triu!(M::Bidiagonal, k::Integer=0)
n = length(M.dv)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif !M.isupper && k > 0
return Bidiagonal(zeros(M.dv),zeros(M.ev),M.isupper)
fill!(M.dv,0)
fill!(M.ev,0)
elseif k > 1
return Bidiagonal(zeros(M.dv),zeros(M.ev),M.isupper)
elseif M.isupper && k == 0
return M
fill!(M.dv,0)
fill!(M.ev,0)
elseif !M.isupper && k == 0
return Bidiagonal(M.dv,zeros(M.ev),M.isupper)
fill!(M.ev,0)
elseif M.isupper && k == 1
return Bidiagonal(zeros(M.dv),M.ev,M.isupper)
elseif k < 0
return M
fill!(M.dv,0)
end
return M
end

function diag{T}(M::Bidiagonal{T}, n::Integer=0)
Expand Down
23 changes: 21 additions & 2 deletions base/linalg/diagonal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -55,8 +55,27 @@ isposdef(D::Diagonal) = all(D.diag .> 0)

factorize(D::Diagonal) = D

tril(D::Diagonal,i::Integer=0) = i == 0 ? D : zeros(D)
triu(D::Diagonal,i::Integer=0) = i == 0 ? D : zeros(D)
istriu(D::Diagonal) = true
istril(D::Diagonal) = true
function triu!(D::Diagonal,k::Integer=0)
n = size(D,1)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k > 0
fill!(D.diag,0)
end
return D
end

function tril!(D::Diagonal,k::Integer=0)
n = size(D,1)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k < 0
fill!(D.diag,0)
end
return D
end

==(Da::Diagonal, Db::Diagonal) = Da.diag == Db.diag
-(A::Diagonal)=Diagonal(-A.diag)
Expand Down
2 changes: 2 additions & 0 deletions base/linalg/generic.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,6 +40,8 @@ cross(a::AbstractVector, b::AbstractVector) = [a[2]*b[3]-a[3]*b[2], a[3]*b[1]-a[

triu(M::AbstractMatrix) = triu!(copy(M))
tril(M::AbstractMatrix) = tril!(copy(M))
triu(M::AbstractMatrix,k::Integer) = triu!(copy(M),k)
tril(M::AbstractMatrix,k::Integer) = tril!(copy(M),k)
triu!(M::AbstractMatrix) = triu!(M,0)
tril!(M::AbstractMatrix) = tril!(M,0)

Expand Down
51 changes: 47 additions & 4 deletions base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -57,10 +57,53 @@ ctranspose(A::Hermitian) = A
trace(A::Hermitian) = real(trace(A.data))

#tril/triu
tril(A::Hermitian,k::Integer=0) = tril(A.data,k)
triu(A::Hermitian,k::Integer=0) = triu(A.data,k)
tril(A::Symmetric,k::Integer=0) = tril(A.data,k)
triu(A::Symmetric,k::Integer=0) = triu(A.data,k)
function tril(A::Hermitian, k::Integer=0)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This could be Union{Hermitian,Symmetric} right? The implementations look like exact copies.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It passes tests locally with this change. Do we need to run CI again?

if A.uplo == 'U' && k <= 0
return tril!(A.data',k)
elseif A.uplo == 'U' && k > 0
return tril!(A.data',-1) + tril!(triu(A.data),k)
elseif A.uplo == 'L' && k <= 0
return tril(A.data,k)
else
return tril(A.data,-1) + tril!(triu!(A.data'),k)
end
end

function tril(A::Symmetric, k::Integer=0)
if A.uplo == 'U' && k <= 0
return tril!(A.data.',k)
elseif A.uplo == 'U' && k > 0
return tril!(A.data.',-1) + tril!(triu(A.data),k)
elseif A.uplo == 'L' && k <= 0
return tril(A.data,k)
else
return tril(A.data,-1) + tril!(triu!(A.data.'),k)
end
end

function triu(A::Hermitian, k::Integer=0)
if A.uplo == 'U' && k >= 0
return triu(A.data,k)
elseif A.uplo == 'U' && k < 0
return triu(A.data,1) + triu!(tril!(A.data'),k)
elseif A.uplo == 'L' && k >= 0
return triu!(A.data',k)
else
return triu!(A.data',1) + triu!(tril(A.data),k)
end
end

function triu(A::Symmetric, k::Integer=0)
if A.uplo == 'U' && k >= 0
return triu(A.data,k)
elseif A.uplo == 'U' && k < 0
return triu(A.data,1) + triu!(tril!(A.data.'),k)
elseif A.uplo == 'L' && k >= 0
return triu!(A.data.',k)
else
return triu!(A.data.',1) + triu!(tril(A.data),k)
end
end

## Matvec
A_mul_B!{T<:BlasFloat,S<:StridedMatrix}(y::StridedVector{T}, A::Symmetric{T,S}, x::StridedVector{T}) = BLAS.symv!(A.uplo, one(T), A.data, x, zero(T), y)
Expand Down
75 changes: 54 additions & 21 deletions base/linalg/triangular.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,65 +149,98 @@ istril(A::UnitLowerTriangular) = true
istriu(A::UpperTriangular) = true
istriu(A::UnitUpperTriangular) = true

function tril(A::UpperTriangular,k::Integer=0)
function tril!(A::UpperTriangular,k::Integer=0)
n = size(A,1)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k < 0
return UpperTriangular(zeros(A.data))
fill!(A.data,0)
return A
elseif k == 0
return UpperTriangular(diagm(diag(A)))
for j in 1:n, i in 1:j-1
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Sorry for not suggesting this earlier, but couldn't all of these cases except for the error branch be equivalently stated as UpperTriangular(tril!(A.data,k))? Wouldn't that do the same amount of work and be equivalent even for k <= 0?

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Can we fix the Symmetric methods and leave it there, or do we really want to fix this too?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I think it's all correct and good to merge now. Making it more concise can be left for later. Someone also needs to fix the sub-vs-superdiagonal thing in the docs but that doesn't have to be done here. I'll do it if no one beats me to it and I don't forget about it.

A.data[i,j] = 0
end
return A
else
return UpperTriangular(triu(tril(A.data,k)))
return UpperTriangular(tril!(A.data,k))
end
end
triu!(A::UpperTriangular,k::Integer=0) = UpperTriangular(triu!(A.data,k))

triu(A::UpperTriangular,k::Integer=0) = UpperTriangular(triu(triu(A.data),k))

function tril(A::UnitUpperTriangular,k::Integer=0)
function tril!(A::UnitUpperTriangular,k::Integer=0)
n = size(A,1)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k < 0
return UnitUpperTriangular(zeros(A.data))
fill!(A.data,0)
return UpperTriangular(A.data)
elseif k == 0
return UnitUpperTriangular(eye(A))
fill!(A.data,0)
for i in diagind(A)
A.data[i] = one(eltype(A))
end
return UpperTriangular(A.data)
else
return UnitUpperTriangular(triu(tril(A.data,k)))
for i in diagind(A)
A.data[i] = one(eltype(A))
end
return UpperTriangular(tril!(A.data,k))
end
end

triu(A::UnitUpperTriangular,k::Integer=0) = UnitUpperTriangular(triu(triu(A.data),k))
function triu!(A::UnitUpperTriangular,k::Integer=0)
for i in diagind(A)
A.data[i] = one(eltype(A))
end
return triu!(UpperTriangular(A.data),k)
end

function triu(A::LowerTriangular,k::Integer=0)
function triu!(A::LowerTriangular,k::Integer=0)
n = size(A,1)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k > 0
return LowerTriangular(zeros(A.data))
fill!(A.data,0)
return A
elseif k == 0
return LowerTriangular(diagm(diag(A)))
for j in 1:n, i in j+1:n
A.data[i,j] = 0
end
return A
else
return LowerTriangular(tril(triu(A.data,k)))
return LowerTriangular(triu!(A.data,k))
end
end

tril(A::LowerTriangular,k::Integer=0) = LowerTriangular(tril(tril(A.data),k))
tril!(A::LowerTriangular,k::Integer=0) = LowerTriangular(tril!(A.data,k))

function triu(A::UnitLowerTriangular,k::Integer=0)
function triu!(A::UnitLowerTriangular,k::Integer=0)
n = size(A,1)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k > 0
return UnitLowerTriangular(zeros(A.data))
fill!(A.data,0)
return LowerTriangular(A.data)
elseif k == 0
return UnitLowerTriangular(eye(A))
fill!(A.data,0)
for i in diagind(A)
A.data[i] = one(eltype(A))
end
return LowerTriangular(A.data)
else
return UnitLowerTriangular(tril(triu(A.data,k)))
for i in diagind(A)
A.data[i] = one(eltype(A))
end
return LowerTriangular(triu!(A.data,k))
end
end

tril(A::UnitLowerTriangular,k::Integer=0) = UnitLowerTriangular(tril(tril(A.data),k))
function tril!(A::UnitLowerTriangular,k::Integer=0)
for i in diagind(A)
A.data[i] = one(eltype(A))
end
return tril!(LowerTriangular(A.data),k)
end

transpose{T,S}(A::LowerTriangular{T,S}) = UpperTriangular{T, S}(transpose(A.data))
transpose{T,S}(A::UnitLowerTriangular{T,S}) = UnitUpperTriangular{T, S}(transpose(A.data))
Expand Down
54 changes: 32 additions & 22 deletions base/linalg/tridiag.jl
Original file line number Diff line number Diff line change
Expand Up @@ -149,33 +149,39 @@ eigvecs{T<:BlasFloat,Eigenvalue<:Real}(A::SymTridiagonal{T}, eigvals::Vector{Eig
istriu(M::SymTridiagonal) = all(M.ev .== 0)
istril(M::SymTridiagonal) = all(M.ev .== 0)

function tril(M::SymTridiagonal, k::Integer=0)
function tril!(M::SymTridiagonal, k::Integer=0)
n = length(M.dv)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k < -1
return SymTridiagonal(zeros(M.dv),zeros(M.ev))
fill!(M.ev,0)
fill!(M.dv,0)
return Tridiagonal(M.ev,M.dv,copy(M.ev))
elseif k == -1
return Tridiagonal(M.ev,zeros(M.dv),zeros(M.ev))
fill!(M.dv,0)
return Tridiagonal(M.ev,M.dv,zeros(M.ev))
elseif k == 0
return Tridiagonal(M.ev,M.dv,zeros(M.ev))
elseif k >= 1
return M
return Tridiagonal(M.ev,M.dv,copy(M.ev))
end
end

function triu(M::SymTridiagonal, k::Integer=0)
function triu!(M::SymTridiagonal, k::Integer=0)
n = length(M.dv)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k > 1
return SymTridiagonal(zeros(M.dv),zeros(M.ev))
fill!(M.ev,0)
fill!(M.dv,0)
return Tridiagonal(M.ev,M.dv,copy(M.ev))
elseif k == 1
return Tridiagonal(zeros(M.ev),zeros(M.dv),M.ev)
fill!(M.dv,0)
return Tridiagonal(zeros(M.ev),M.dv,M.ev)
elseif k == 0
return Tridiagonal(zeros(M.ev),M.dv,M.ev)
elseif k <= -1
return M
return Tridiagonal(M.ev,M.dv,copy(M.ev))
end
end

Expand Down Expand Up @@ -356,37 +362,41 @@ end

#tril and triu

istriu(M::Tridiagonal) = all(M.ev .== 0)
istril(M::Tridiagonal) = all(M.ev .== 0)
istriu(M::Tridiagonal) = all(M.dl .== 0)
istril(M::Tridiagonal) = all(M.du .== 0)

function tril(M::Tridiagonal, k::Integer=0)
function tril!(M::Tridiagonal, k::Integer=0)
n = length(M.d)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k < -1
return Tridiagonal(zeros(M.dl),zeros(M.d),zeros(M.du))
fill!(M.dl,0)
fill!(M.d,0)
fill!(M.du,0)
elseif k == -1
return Tridiagonal(M.dl,zeros(M.d),zeros(M.du))
fill!(M.d,0)
fill!(M.du,0)
elseif k == 0
return Tridiagonal(M.dl,M.d,zeros(M.du))
elseif k >= 1
return M
fill!(M.du,0)
end
return M
end

function triu(M::Tridiagonal, k::Integer=0)
function triu!(M::Tridiagonal, k::Integer=0)
n = length(M.d)
if abs(k) > n
throw(ArgumentError("requested diagonal, $k, out of bounds in matrix of size ($n,$n)"))
elseif k > 1
return Tridiagonal(zeros(M.dl),zeros(M.d),zeros(M.du))
fill!(M.dl,0)
fill!(M.d,0)
fill!(M.du,0)
elseif k == 1
return Tridiagonal(zeros(M.dl),zeros(M.d),M.du)
fill!(M.dl,0)
fill!(M.d,0)
elseif k == 0
return Tridiagonal(zeros(M.dl),M.d,M.du)
elseif k <= -1
return M
fill!(M.dl,0)
end
return M
end

###################
Expand Down
Loading