From a21980b632f7d362d8974e694cd756fb1abc9099 Mon Sep 17 00:00:00 2001 From: Massimiliano Fasi Date: Tue, 4 Aug 2015 18:10:23 +0100 Subject: [PATCH] Add macros to generate matrix functions for symmetric and Hermitian matrices, fix type returned by logm and sqrtm --- base/linalg/dense.jl | 14 ++++++--- base/linalg/eigen.jl | 2 +- base/linalg/symmetric.jl | 61 ++++++++++++++++++++++++++++++++++------ 3 files changed, 63 insertions(+), 14 deletions(-) diff --git a/base/linalg/dense.jl b/base/linalg/dense.jl index eaf69aa48aa49c..844ef04ded02bc 100644 --- a/base/linalg/dense.jl +++ b/base/linalg/dense.jl @@ -189,7 +189,9 @@ expm(x::Number) = exp(x) ## "Functions of Matrices: Theory and Computation", SIAM function expm!{T<:BlasFloat}(A::StridedMatrix{T}) n = chksquare(A) - n<2 && return exp(A) + if ishermitian(A) + return full(expm(Hermitian(A))) + end ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A nA = norm(A, 1) I = eye(T,n) @@ -278,10 +280,12 @@ function rcswap!{T<:Number}(i::Integer, j::Integer, X::StridedMatrix{T}) end end +expm(x::Number) = exp(x) + function logm(A::StridedMatrix) # If possible, use diagonalization if ishermitian(A) - return logm(Hermitian(A)) + return full(logm(Hermitian(A))) end # Use Schur decomposition @@ -313,12 +317,13 @@ function logm(A::StridedMatrix) return retmat end end + logm(a::Number) = (b = log(complex(a)); imag(b) == 0 ? real(b) : b) logm(a::Complex) = log(a) function sqrtm{T<:Real}(A::StridedMatrix{T}) if issym(A) - return sqrtm(Symmetric(A)) + return full(sqrtm(Symmetric(A))) end n = chksquare(A) SchurF = schurfact(complex(A)) @@ -328,13 +333,14 @@ function sqrtm{T<:Real}(A::StridedMatrix{T}) end function sqrtm{T<:Complex}(A::StridedMatrix{T}) if ishermitian(A) - return sqrtm(Hermitian(A)) + return full(sqrtm(Hermitian(A))) end n = chksquare(A) SchurF = schurfact(A) R = full(sqrtm(UpperTriangular(SchurF[:T]))) SchurF[:vectors]*R*SchurF[:vectors]' end + sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b) sqrtm(a::Complex) = sqrt(a) diff --git a/base/linalg/eigen.jl b/base/linalg/eigen.jl index c8ba22927ac5d4..74bb4f3daff7e5 100644 --- a/base/linalg/eigen.jl +++ b/base/linalg/eigen.jl @@ -23,7 +23,7 @@ function getindex(A::Union{Eigen,GeneralizedEigen}, d::Symbol) throw(KeyError(d)) end -isposdef(A::Union{Eigen,GeneralizedEigen}) = all(A.values .> 0) +isposdef(A::Union{Eigen,GeneralizedEigen}) = isreal(A.values) && all(A.values .> 0) function eigfact!{T<:BlasReal}(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true) n = size(A, 2) diff --git a/base/linalg/symmetric.jl b/base/linalg/symmetric.jl index a5877202effb1c..f507ee4cbb2418 100644 --- a/base/linalg/symmetric.jl +++ b/base/linalg/symmetric.jl @@ -126,14 +126,57 @@ function svdvals!{T<:Real,S}(A::Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{ end #Matrix-valued functions -expm{T<:Real}(A::RealHermSymComplexHerm{T}) = (F = eigfact(A); F.vectors*Diagonal(exp(F.values))*F.vectors') -function logm{T<:Real}(A::RealHermSymComplexHerm{T}) - F = eigfact(A) - isposdef(F) && return F.vectors*Diagonal(log(F.values))*F.vectors' - return F.vectors*Diagonal(log(complex(F.values)))*F.vectors' +function expm(A::Symmetric) + F = eigfact(full(A)) + return Symmetric(F.vectors * Diagonal(exp(F.values)) * F.vectors') end -function sqrtm{T<:Real}(A::RealHermSymComplexHerm{T}) - F = eigfact(A) - isposdef(F) && return F.vectors*Diagonal(sqrt(F.values))*F.vectors' - return F.vectors*Diagonal(sqrt(complex(F.values)))*F.vectors' +function expm{T}(A::Hermitian{T}) + n = chksquare(A) + F = eigfact(full(A)) + retmat = F.vectors * Diagonal(exp(F.values)) * F.vectors' + if T <: Real + return real(Hermitian(retmat)) + else + for i = 1:n + retmat[i,i] = real(retmat[i,i]) + end + return Hermitian(retmat) + end +end + +for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt]) + + @eval begin + + function ($funm)(A::Symmetric) + F = eigfact(full(A)) + if isposdef(F) + retmat = F.vectors * Diagonal(($func)(F.values)) * F.vectors' + else + retmat = F.vectors * Diagonal(($func)(complex(F.values))) * F.vectors' + end + return Symmetric(retmat) + end + + function ($funm){T}(A::Hermitian{T}) + n = chksquare(A) + F = eigfact(A) + if isposdef(F) + retmat = F.vectors * Diagonal(($func)(F.values)) * F.vectors' + if T <: Real + return real(Hermitian(retmat)) + else + for i = 1:n + retmat[i,i] = real(retmat[i,i]) + end + return Hermitian(retmat) + end + else + retmat = F.vectors * Diagonal(($func)(complex(F.values))) * F.vectors' + return retmat + end + end + + end + end