Skip to content

Commit

Permalink
Add macros to generate matrix functions for symmetric and Hermitian m…
Browse files Browse the repository at this point in the history
…atrices, fix type returned by logm and sqrtm
  • Loading branch information
mfasi committed Aug 10, 2015
1 parent 87a956b commit 5d71e05
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 14 deletions.
14 changes: 10 additions & 4 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ expm(x::Number) = exp(x)
## "Functions of Matrices: Theory and Computation", SIAM
function expm!{T<:BlasFloat}(A::StridedMatrix{T})
n = chksquare(A)
n<2 && return exp(A)
if ishermitian(A)
return full(expm(Hermitian(A)))
end
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
nA = norm(A, 1)
I = eye(T,n)
Expand Down Expand Up @@ -278,10 +280,12 @@ function rcswap!{T<:Number}(i::Integer, j::Integer, X::StridedMatrix{T})
end
end

expm(x::Number) = exp(x)

function logm(A::StridedMatrix)
# If possible, use diagonalization
if ishermitian(A)
return logm(Hermitian(A))
return full(logm(Hermitian(A)))
end

# Use Schur decomposition
Expand Down Expand Up @@ -313,12 +317,13 @@ function logm(A::StridedMatrix)
return retmat
end
end

logm(a::Number) = (b = log(complex(a)); imag(b) == 0 ? real(b) : b)
logm(a::Complex) = log(a)

function sqrtm{T<:Real}(A::StridedMatrix{T})
if issym(A)
return sqrtm(Symmetric(A))
return full(sqrtm(Symmetric(A)))
end
n = chksquare(A)
SchurF = schurfact(complex(A))
Expand All @@ -328,13 +333,14 @@ function sqrtm{T<:Real}(A::StridedMatrix{T})
end
function sqrtm{T<:Complex}(A::StridedMatrix{T})
if ishermitian(A)
return sqrtm(Hermitian(A))
return full(sqrtm(Hermitian(A)))
end
n = chksquare(A)
SchurF = schurfact(A)
R = full(sqrtm(UpperTriangular(SchurF[:T])))
SchurF[:vectors]*R*SchurF[:vectors]'
end

sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b)
sqrtm(a::Complex) = sqrt(a)

Expand Down
2 changes: 1 addition & 1 deletion base/linalg/eigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function getindex(A::Union{Eigen,GeneralizedEigen}, d::Symbol)
throw(KeyError(d))
end

isposdef(A::Union{Eigen,GeneralizedEigen}) = all(A.values .> 0)
isposdef(A::Union{Eigen,GeneralizedEigen}) = isreal(A.values) && all(A.values .> 0)

function eigfact!{T<:BlasReal}(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true)
n = size(A, 2)
Expand Down
61 changes: 52 additions & 9 deletions base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,57 @@ function svdvals!{T<:Real,S}(A::Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{
end

#Matrix-valued functions
expm{T<:Real}(A::RealHermSymComplexHerm{T}) = (F = eigfact(A); F.vectors*Diagonal(exp(F.values))*F.vectors')
function logm{T<:Real}(A::RealHermSymComplexHerm{T})
F = eigfact(A)
isposdef(F) && return F.vectors*Diagonal(log(F.values))*F.vectors'
return F.vectors*Diagonal(log(complex(F.values)))*F.vectors'
function expm(A::Symmetric)
F = eigfact(full(A))
return Symmetric((F.vectors * Diagonal(exp(F.values))) * F.vectors')
end
function sqrtm{T<:Real}(A::RealHermSymComplexHerm{T})
F = eigfact(A)
isposdef(F) && return F.vectors*Diagonal(sqrt(F.values))*F.vectors'
return F.vectors*Diagonal(sqrt(complex(F.values)))*F.vectors'
function expm{T}(A::Hermitian{T})
n = chksquare(A)
F = eigfact(full(A))
retmat = (F.vectors * Diagonal(exp(F.values))) * F.vectors'
if T <: Real
return real(Hermitian(retmat))
else
for i = 1:n
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
end

for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt])

@eval begin

function ($funm)(A::Symmetric)
F = eigfact(full(A))
if isposdef(F)
retmat = (F.vectors * Diagonal(($func)(F.values))) * F.vectors'
else
retmat = (F.vectors * Diagonal(($func)(complex(F.values)))) * F.vectors'
end
return Symmetric(retmat)
end

function ($funm){T}(A::Hermitian{T})
n = chksquare(A)
F = eigfact(A)
if isposdef(F)
retmat = (F.vectors * Diagonal(($func)(F.values))) * F.vectors'
if T <: Real
return Hermitian(retmat)
else
for i = 1:n
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
else
retmat = (F.vectors * Diagonal(($func)(complex(F.values)))) * F.vectors'
return retmat
end
end

end

end

0 comments on commit 5d71e05

Please sign in to comment.