Skip to content

Commit

Permalink
Merge pull request #12408 from mfasi/matfun_special_matrices
Browse files Browse the repository at this point in the history
Consistency of return type for matrix functions
  • Loading branch information
andreasnoack committed Aug 10, 2015
2 parents 87a956b + 5d71e05 commit 79604f2
Show file tree
Hide file tree
Showing 3 changed files with 63 additions and 14 deletions.
14 changes: 10 additions & 4 deletions base/linalg/dense.jl
Original file line number Diff line number Diff line change
Expand Up @@ -189,7 +189,9 @@ expm(x::Number) = exp(x)
## "Functions of Matrices: Theory and Computation", SIAM
function expm!{T<:BlasFloat}(A::StridedMatrix{T})
n = chksquare(A)
n<2 && return exp(A)
if ishermitian(A)
return full(expm(Hermitian(A)))
end
ilo, ihi, scale = LAPACK.gebal!('B', A) # modifies A
nA = norm(A, 1)
I = eye(T,n)
Expand Down Expand Up @@ -278,10 +280,12 @@ function rcswap!{T<:Number}(i::Integer, j::Integer, X::StridedMatrix{T})
end
end

expm(x::Number) = exp(x)

function logm(A::StridedMatrix)
# If possible, use diagonalization
if ishermitian(A)
return logm(Hermitian(A))
return full(logm(Hermitian(A)))
end

# Use Schur decomposition
Expand Down Expand Up @@ -313,12 +317,13 @@ function logm(A::StridedMatrix)
return retmat
end
end

logm(a::Number) = (b = log(complex(a)); imag(b) == 0 ? real(b) : b)
logm(a::Complex) = log(a)

function sqrtm{T<:Real}(A::StridedMatrix{T})
if issym(A)
return sqrtm(Symmetric(A))
return full(sqrtm(Symmetric(A)))
end
n = chksquare(A)
SchurF = schurfact(complex(A))
Expand All @@ -328,13 +333,14 @@ function sqrtm{T<:Real}(A::StridedMatrix{T})
end
function sqrtm{T<:Complex}(A::StridedMatrix{T})
if ishermitian(A)
return sqrtm(Hermitian(A))
return full(sqrtm(Hermitian(A)))
end
n = chksquare(A)
SchurF = schurfact(A)
R = full(sqrtm(UpperTriangular(SchurF[:T])))
SchurF[:vectors]*R*SchurF[:vectors]'
end

sqrtm(a::Number) = (b = sqrt(complex(a)); imag(b) == 0 ? real(b) : b)
sqrtm(a::Complex) = sqrt(a)

Expand Down
2 changes: 1 addition & 1 deletion base/linalg/eigen.jl
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,7 @@ function getindex(A::Union{Eigen,GeneralizedEigen}, d::Symbol)
throw(KeyError(d))
end

isposdef(A::Union{Eigen,GeneralizedEigen}) = all(A.values .> 0)
isposdef(A::Union{Eigen,GeneralizedEigen}) = isreal(A.values) && all(A.values .> 0)

function eigfact!{T<:BlasReal}(A::StridedMatrix{T}; permute::Bool=true, scale::Bool=true)
n = size(A, 2)
Expand Down
61 changes: 52 additions & 9 deletions base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
Expand Up @@ -126,14 +126,57 @@ function svdvals!{T<:Real,S}(A::Union{Hermitian{T,S}, Symmetric{T,S}, Hermitian{
end

#Matrix-valued functions
expm{T<:Real}(A::RealHermSymComplexHerm{T}) = (F = eigfact(A); F.vectors*Diagonal(exp(F.values))*F.vectors')
function logm{T<:Real}(A::RealHermSymComplexHerm{T})
F = eigfact(A)
isposdef(F) && return F.vectors*Diagonal(log(F.values))*F.vectors'
return F.vectors*Diagonal(log(complex(F.values)))*F.vectors'
function expm(A::Symmetric)
F = eigfact(full(A))
return Symmetric((F.vectors * Diagonal(exp(F.values))) * F.vectors')
end
function sqrtm{T<:Real}(A::RealHermSymComplexHerm{T})
F = eigfact(A)
isposdef(F) && return F.vectors*Diagonal(sqrt(F.values))*F.vectors'
return F.vectors*Diagonal(sqrt(complex(F.values)))*F.vectors'
function expm{T}(A::Hermitian{T})
n = chksquare(A)
F = eigfact(full(A))
retmat = (F.vectors * Diagonal(exp(F.values))) * F.vectors'
if T <: Real
return real(Hermitian(retmat))
else
for i = 1:n
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
end

for (funm, func) in ([:logm,:log], [:sqrtm,:sqrt])

@eval begin

function ($funm)(A::Symmetric)
F = eigfact(full(A))
if isposdef(F)
retmat = (F.vectors * Diagonal(($func)(F.values))) * F.vectors'
else
retmat = (F.vectors * Diagonal(($func)(complex(F.values)))) * F.vectors'
end
return Symmetric(retmat)
end

function ($funm){T}(A::Hermitian{T})
n = chksquare(A)
F = eigfact(A)
if isposdef(F)
retmat = (F.vectors * Diagonal(($func)(F.values))) * F.vectors'
if T <: Real
return Hermitian(retmat)
else
for i = 1:n
retmat[i,i] = real(retmat[i,i])
end
return Hermitian(retmat)
end
else
retmat = (F.vectors * Diagonal(($func)(complex(F.values)))) * F.vectors'
return retmat
end
end

end

end

0 comments on commit 79604f2

Please sign in to comment.