Skip to content

Commit

Permalink
Factorization is now type stable
Browse files Browse the repository at this point in the history
Previously, the returned types depended on the boolean value of the keyword argument `pivot`. Now the API relies on dispatch on different types, either `pivot=Val{true}` or `pivot=Val{false}`. A few more type instabilities were mitigated by using a `copy_oftype` method.
  • Loading branch information
skariel committed Jan 27, 2015
1 parent 8b4e9e9 commit f9da849
Show file tree
Hide file tree
Showing 10 changed files with 95 additions and 72 deletions.
2 changes: 2 additions & 0 deletions NEWS.md
Original file line number Diff line number Diff line change
Expand Up @@ -76,6 +76,8 @@ Compiler improvements
Library improvements
--------------------

* Factorization api is now type-stable, functions dispatch on `Val{false}` or `Val{true}` instead of a boolean value ([#9575]).

* `convert` now checks for overflow when truncating integers or converting between
signed and unsigned ([#5413]).

Expand Down
40 changes: 27 additions & 13 deletions base/linalg/cholesky.jl
Original file line number Diff line number Diff line change
Expand Up @@ -84,14 +84,17 @@ function chol!{T}(A::AbstractMatrix{T}, uplo::Symbol)
return uplo == :U ? UpperTriangular(A) : LowerTriangular(A)
end

function cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0)
cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}; tol=0.0) =
_cholfact!(A, pivot, uplo, tol=tol)
function _cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, ::Type{Val{false}}, uplo::Symbol=:U; tol=0.0)
uplochar = char_uplo(uplo)
if pivot
A, piv, rank, info = LAPACK.pstrf!(uplochar, A, tol)
return CholeskyPivoted{T,typeof(A)}(A, uplochar, piv, rank, tol, info)
end
return Cholesky(chol!(A, uplo).data, uplo)
end
function _cholfact!{T<:BlasFloat}(A::StridedMatrix{T}, ::Type{Val{true}}, uplo::Symbol=:U; tol=0.0)
uplochar = char_uplo(uplo)
A, piv, rank, info = LAPACK.pstrf!(uplochar, A, tol)
return CholeskyPivoted{T,StridedMatrix{T}}(A, uplochar, piv, rank, tol, info)
end
cholfact!(A::AbstractMatrix, uplo::Symbol=:U) = Cholesky(chol!(A, uplo).data, uplo)

function cholfact!{T<:BlasFloat,S,UpLo}(C::Cholesky{T,S,UpLo})
Expand All @@ -100,14 +103,25 @@ function cholfact!{T<:BlasFloat,S,UpLo}(C::Cholesky{T,S,UpLo})
C
end

cholfact{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0) = cholfact!(copy(A), uplo, pivot=pivot, tol=tol)
function cholfact{T}(A::StridedMatrix{T}, uplo::Symbol=:U; pivot=false, tol=0.0)
S = promote_type(typeof(chol(one(T))),Float32)
S <: BlasFloat && return cholfact!(convert(AbstractMatrix{S}, A), uplo, pivot = pivot, tol = tol)
pivot && throw(ArgumentError("pivot only supported for Float32, Float64, Complex{Float32} and Complex{Float64}"))
S != T && return cholfact!(convert(AbstractMatrix{S}, A), uplo)
return cholfact!(copy(A), uplo)
end
cholfact{T<:BlasFloat}(A::StridedMatrix{T}, uplo::Symbol=:U, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}; tol=0.0) =
cholfact!(copy(A), uplo, pivot, tol=tol)


copy_oftype{T}(A::StridedMatrix{T}, ::Type{T}) = copy(A)
copy_oftype{T,S}(A::StridedMatrix{T}, ::Type{S}) = convert(AbstractMatrix{S}, A)
cholfact{T}(A::StridedMatrix{T}, uplo::Symbol=:U, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}; tol=0.0) =
_cholfact(copy_oftype(A, promote_type(typeof(chol(one(T))),Float32)), pivot, uplo, tol=tol)
_cholfact{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Type{Val{true}}, uplo::Symbol=:U; tol=0.0) =
cholfact!(A, uplo, pivot, tol = tol)
_cholfact{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Type{Val{false}}, uplo::Symbol=:U; tol=0.0) =
cholfact!(A, uplo, pivot, tol = tol)

_cholfact{T}(A::StridedMatrix{T}, ::Type{Val{false}}, uplo::Symbol=:U; tol=0.0) =
cholfact!(A, uplo)
_cholfact{T}(A::StridedMatrix{T}, ::Type{Val{true}}, uplo::Symbol=:U; tol=0.0) =
throw(ArgumentError("pivot only supported for Float32, Float64, Complex{Float32} and Complex{Float64}"))


function cholfact(x::Number, uplo::Symbol=:U)
xf = fill(chol!(x, uplo), 1, 1)
Cholesky(xf, uplo)
Expand Down
21 changes: 12 additions & 9 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -369,20 +369,23 @@ function factorize{T}(A::Matrix{T})
end
return lufact(A)
end
qrfact(A,pivot=typeof(zero(T)/sqrt(zero(T) + zero(T)))<:BlasFloat) # Generic pivoted QR not implemented yet
qrfact(A,typeof(zero(T)/sqrt(zero(T) + zero(T)))<:BlasFloat?Val{true}:Val{false}) # Generic pivoted QR not implemented yet
end

(\)(a::Vector, B::StridedVecOrMat) = (\)(reshape(a, length(a), 1), B)
function (\)(A::StridedMatrix, B::StridedVecOrMat)
m, n = size(A)
if m == n
if istril(A)
return istriu(A) ? \(Diagonal(A),B) : \(LowerTriangular(A),B)

for (T1,PIVOT) in ((BlasFloat,Val{true}),(Any,Val{false}))
@eval function (\){T<:$T1}(A::StridedMatrix{T}, B::StridedVecOrMat)
m, n = size(A)
if m == n
if istril(A)
return istriu(A) ? \(Diagonal(A),B) : \(LowerTriangular(A),B)
end
istriu(A) && return \(UpperTriangular(A),B)
return \(lufact(A),B)
end
istriu(A) && return \(UpperTriangular(A),B)
return \(lufact(A),B)
return qrfact(A,$PIVOT)\B
end
return qrfact(A,pivot=eltype(A)<:BlasFloat)\B
end

## Moore-Penrose inverse
Expand Down
28 changes: 16 additions & 12 deletions base/linalg/factorization.jl
Original file line number Diff line number Diff line change
Expand Up @@ -40,9 +40,8 @@ immutable QRPivoted{T,S<:AbstractMatrix} <: Factorization{T}
end
QRPivoted{T}(factors::AbstractMatrix{T}, τ::Vector{T}, jpvt::Vector{BlasInt}) = QRPivoted{T,typeof(factors)}(factors, τ, jpvt)

qrfact!{T<:BlasFloat}(A::StridedMatrix{T}; pivot=false) = pivot ? QRPivoted(LAPACK.geqp3!(A)...) : QRCompactWY(LAPACK.geqrt!(A, min(minimum(size(A)), 36))...)
function qrfact!{T}(A::AbstractMatrix{T}; pivot=false)
pivot && warn("pivoting only implemented for Float32, Float64, Complex64 and Complex128")
function qrfact!{T}(A::AbstractMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false})
pivot==Val{true} && warn("pivoting only implemented for Float32, Float64, Complex64 and Complex128")
m, n = size(A)
τ = zeros(T, min(m,n))
@inbounds begin
Expand All @@ -64,17 +63,22 @@ function qrfact!{T}(A::AbstractMatrix{T}; pivot=false)
end
QR(A, τ)
end
qrfact{T<:BlasFloat}(A::StridedMatrix{T}; pivot=false) = qrfact!(copy(A),pivot=pivot)
qrfact{T}(A::StridedMatrix{T}; pivot=false) = (S = typeof(one(T)/norm(one(T)));S != T ? qrfact!(convert(AbstractMatrix{S},A), pivot=pivot) : qrfact!(copy(A),pivot=pivot))
qrfact!{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}) = pivot==Val{true} ? QRPivoted(LAPACK.geqp3!(A)...) : QRCompactWY(LAPACK.geqrt!(A, min(minimum(size(A)), 36))...)
qrfact{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}) = qrfact!(copy(A), pivot)
copy_oftype{T}(A::StridedMatrix{T}, ::Type{T}) = copy(A)
copy_oftype{T,S}(A::StridedMatrix{T}, ::Type{S}) = convert(AbstractMatrix{S}, A)
qrfact{T}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}) = qrfact!(copy_oftype(A, typeof(one(T)/norm(one(T)))), pivot)
qrfact(x::Number) = qrfact(fill(x,1,1))

function qr(A::Union(Number, AbstractMatrix); pivot::Bool=false, thin::Bool=true)
F = qrfact(A, pivot=pivot)
if pivot
full(F[:Q], thin=thin), F[:R], F[:p]
else
full(F[:Q], thin=thin), F[:R]
end
qr(A::Union(Number, AbstractMatrix), pivot::Union(Type{Val{false}}, Type{Val{true}})=Val{false}; thin::Bool=true) =
_qr(A, pivot, thin=thin)
function _qr(A::Union(Number, AbstractMatrix), ::Type{Val{false}}; thin::Bool=true)
F = qrfact(A, Val{false})
full(F[:Q], thin=thin), F[:R]
end
function _qr(A::Union(Number, AbstractMatrix), ::Type{Val{true}}; thin::Bool=true)
F = qrfact(A, Val{true})
full(F[:Q], thin=thin), F[:R], F[:p]
end

convert{T}(::Type{QR{T}},A::QR) = QR(convert(AbstractMatrix{T}, A.factors), convert(Vector{T}, A.τ))
Expand Down
24 changes: 12 additions & 12 deletions base/linalg/lu.jl
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ end
LU{T}(factors::AbstractMatrix{T}, ipiv::Vector{BlasInt}, info::BlasInt) = LU{T,typeof(factors)}(factors, ipiv, info)

# StridedMatrix
function lufact!{T<:BlasFloat}(A::StridedMatrix{T}; pivot = true)
!pivot && return generic_lufact!(A, pivot=pivot)
function lufact!{T<:BlasFloat}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true})
pivot==Val{false} && return generic_lufact!(A, pivot)
lpt = LAPACK.getrf!(A)
return LU{T,typeof(A)}(lpt[1], lpt[2], lpt[3])
end
lufact!(A::StridedMatrix; pivot = true) = generic_lufact!(A, pivot=pivot)
function generic_lufact!{T}(A::StridedMatrix{T}; pivot = true)
lufact!(A::StridedMatrix, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true}) = generic_lufact!(A, pivot)
function generic_lufact!{T}(A::StridedMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true})
m, n = size(A)
minmn = min(m,n)
info = 0
Expand All @@ -25,7 +25,7 @@ function generic_lufact!{T}(A::StridedMatrix{T}; pivot = true)
for k = 1:minmn
# find index max
kp = k
if pivot
if pivot==Val{true}
amax = real(zero(T))
for i = k:m
absi = abs(A[i,k])
Expand Down Expand Up @@ -63,14 +63,14 @@ function generic_lufact!{T}(A::StridedMatrix{T}; pivot = true)
end
LU{T,typeof(A)}(A, ipiv, convert(BlasInt, info))
end
lufact{T<:BlasFloat}(A::AbstractMatrix{T}; pivot = true) = lufact!(copy(A), pivot=pivot)
lufact{T}(A::AbstractMatrix{T}; pivot = true) = (S = typeof(zero(T)/one(T)); S != T ? lufact!(convert(AbstractMatrix{S}, A), pivot=pivot) : lufact!(copy(A), pivot=pivot))
lufact{T<:BlasFloat}(A::AbstractMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true}) = lufact!(copy(A), pivot)
lufact{T}(A::AbstractMatrix{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true}) = (S = typeof(zero(T)/one(T)); S != T ? lufact!(convert(AbstractMatrix{S}, A), pivot) : lufact!(copy(A), pivot))
lufact(x::Number) = LU(fill(x, 1, 1), BlasInt[1], x == 0 ? one(BlasInt) : zero(BlasInt))
lufact(F::LU) = F

lu(x::Number) = (one(x), x, 1)
function lu(A::AbstractMatrix; pivot = true)
F = lufact(A, pivot = pivot)
function lu(A::AbstractMatrix, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true})
F = lufact(A, pivot)
F[:L], F[:U], F[:p]
end

Expand Down Expand Up @@ -156,7 +156,7 @@ cond(A::LU, p::Number) = norm(A[:L]*A[:U],p)*norm(inv(A),p)
# Tridiagonal

# See dgttrf.f
function lufact!{T}(A::Tridiagonal{T}; pivot = true)
function lufact!{T}(A::Tridiagonal{T}, pivot::Union(Type{Val{false}}, Type{Val{true}}) = Val{true})
n = size(A, 1)
info = 0
ipiv = Array(BlasInt, n)
Expand All @@ -171,7 +171,7 @@ function lufact!{T}(A::Tridiagonal{T}; pivot = true)
end
for i = 1:n-2
# pivot or not?
if !pivot || abs(d[i]) >= abs(dl[i])
if pivot==Val{false} || abs(d[i]) >= abs(dl[i])
# No interchange
if d[i] != 0
fact = dl[i]/d[i]
Expand All @@ -194,7 +194,7 @@ function lufact!{T}(A::Tridiagonal{T}; pivot = true)
end
if n > 1
i = n-1
if !pivot || abs(d[i]) >= abs(dl[i])
if pivot==Val{false} || abs(d[i]) >= abs(dl[i])
if d[i] != 0
fact = dl[i]/d[i]
dl[i] = fact
Expand Down
22 changes: 11 additions & 11 deletions doc/helpdb.jl
Original file line number Diff line number Diff line change
Expand Up @@ -6459,7 +6459,7 @@ popdisplay(d::Display)
"),

("Base","lufact","lufact(A[, pivot=true]) -> F
("Base","lufact","lufact(A[, pivot=Val{true}]) -> F
Compute the LU factorization of \"A\". The return type of \"F\"
depends on the type of \"A\". In most cases, if \"A\" is a subtype
Expand Down Expand Up @@ -6536,18 +6536,18 @@ popdisplay(d::Display)
"),

("Base","cholfact","cholfact(A, [LU,][pivot=false,][tol=-1.0]) -> Cholesky
("Base","cholfact","cholfact(A [,LU=:U [,pivot=Val{false}]][;tol=-1.0]) -> Cholesky
Compute the Cholesky factorization of a dense symmetric positive
(semi)definite matrix \"A\" and return either a \"Cholesky\" if
\"pivot=false\" or \"CholeskyPivoted\" if \"pivot=true\". \"LU\"
\"pivot==Val{false}\" or \"CholeskyPivoted\" if \"pivot==Val{true}\". \"LU\"
may be \":L\" for using the lower part or \":U\" for the upper
part. The default is to use \":U\". The triangular matrix can be
obtained from the factorization \"F\" with: \"F[:L]\" and
\"F[:U]\". The following functions are available for \"Cholesky\"
objects: \"size\", \"\\\", \"inv\", \"det\". For
\"CholeskyPivoted\" there is also defined a \"rank\". If
\"pivot=false\" a \"PosDefException\" exception is thrown in case
\"pivot==Val{false}\" a \"PosDefException\" exception is thrown in case
the matrix is not positive definite. The argument \"tol\"
determines the tolerance for determining the rank. For negative
values, the tolerance is the machine precision.
Expand All @@ -6574,7 +6574,7 @@ popdisplay(d::Display)
"),

("Base","cholfact!","cholfact!(A, [LU,][pivot=false,][tol=-1.0]) -> Cholesky
("Base","cholfact!","cholfact!(A [,LU=:U,[pivot=Val{false}]][;tol=-1.0]) -> Cholesky
\"cholfact!\" is the same as \"cholfact()\", but saves space by
overwriting the input \"A\", instead of creating a copy.
Expand All @@ -6592,7 +6592,7 @@ popdisplay(d::Display)
"),

("Base","qr","qr(A, [pivot=false,][thin=true]) -> Q, R, [p]
("Base","qr","qr(A [,pivot=Val{false}][;thin=true]) -> Q, R, [p]
Compute the (pivoted) QR factorization of \"A\" such that either
\"A = Q*R\" or \"A[:,p] = Q*R\". Also see \"qrfact\". The default
Expand All @@ -6601,20 +6601,20 @@ popdisplay(d::Display)
"),

("Base","qrfact","qrfact(A[, pivot=false]) -> F
("Base","qrfact","qrfact(A[, pivot=Val{false}]) -> F
Computes the QR factorization of \"A\". The return type of \"F\"
depends on the element type of \"A\" and whether pivoting is
specified (with \"pivot=true\").
specified (with \"pivot==Val{true}\").
+------------------+-------------------+-----------+---------------------------------------+
| Return type | \\\"eltype(A)\\\" | \\\"pivot\\\" | Relationship between \\\"F\\\" and \\\"A\\\" |
+------------------+-------------------+-----------+---------------------------------------+
| \\\"QR\\\" | not \\\"BlasFloat\\\" | either | \\\"A==F[:Q]*F[:R]\\\" |
+------------------+-------------------+-----------+---------------------------------------+
| \\\"QRCompactWY\\\" | \\\"BlasFloat\\\" | \\\"false\\\" | \\\"A==F[:Q]*F[:R]\\\" |
| \\\"QRCompactWY\\\" | \\\"BlasFloat\\\" | \\\"Val{false}\\\" | \\\"A==F[:Q]*F[:R]\\\" |
+------------------+-------------------+-----------+---------------------------------------+
| \\\"QRPivoted\\\" | \\\"BlasFloat\\\" | \\\"true\\\" | \\\"A[:,F[:p]]==F[:Q]*F[:R]\\\" |
| \\\"QRPivoted\\\" | \\\"BlasFloat\\\" | \\\"Val{true}\\\" | \\\"A[:,F[:p]]==F[:Q]*F[:R]\\\" |
+------------------+-------------------+-----------+---------------------------------------+
\"BlasFloat\" refers to any of: \"Float32\", \"Float64\",
Expand Down Expand Up @@ -6681,7 +6681,7 @@ popdisplay(d::Display)
"),

("Base","qrfact!","qrfact!(A[, pivot=false])
("Base","qrfact!","qrfact!(A[, pivot=Val{false}])
\"qrfact!\" is the same as \"qrfact()\", but saves space by
overwriting the input \"A\", instead of creating a copy.
Expand Down
Loading

0 comments on commit f9da849

Please sign in to comment.