Skip to content

Commit

Permalink
Merge pull request #7992 from JuliaLang/anj/sym
Browse files Browse the repository at this point in the history
Make Symmetric/Hermitian parametric on matrix type
  • Loading branch information
andreasnoack committed Aug 14, 2014
2 parents 571eb0d + 370d64d commit eb5b91c
Show file tree
Hide file tree
Showing 2 changed files with 79 additions and 40 deletions.
31 changes: 31 additions & 0 deletions base/linalg/blas.jl
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ export
gbmv,
gemv!,
gemv,
hemv!,
hemv,
sbmv!,
sbmv,
symv!,
Expand Down Expand Up @@ -360,6 +362,35 @@ for (fname, elty) in ((:dsymv_,:Float64),
end
end

### hemv
for (fname, elty) in ((:zhemv_,:Complex128),
(:cgemv_,:Complex64))
@eval begin
function hemv!(uplo::Char, α::$elty, A::StridedMatrix{$elty}, x::StridedVector{$elty}, β::$elty, y::StridedVector{$elty})
n = size(A, 2)
n == length(x) || throw(DimensionMismatch(""))
size(A, 1) == length(y) || throw(DimensionMismatch(""))
lda = max(1, stride(A, 2))
incx = stride(x, 1)
incy = stride(y, 1)
ccall(($fname, libblas), Void,
(Ptr{Uint8}, Ptr{BlasInt}, Ptr{$elty}, Ptr{$elty},
Ptr{BlasInt}, Ptr{$elty}, Ptr{BlasInt}, Ptr{$elty},
Ptr{$elty}, Ptr{BlasInt}),
&uplo, &n, &α, A,
&lda, x, &incx, &β,
y, &incy)
y
end
function hemv(uplo::BlasChar, α::($elty), A::StridedMatrix{$elty}, x::StridedVector{$elty})
hemv!(uplo, α, A, x, zero($elty), similar(x))
end
function hemv(uplo::BlasChar, A::StridedMatrix{$elty}, x::StridedVector{$elty})
hemv(uplo, one($elty), A, x)
end
end
end

### sbmv, (SB) symmetric banded matrix-vector multiplication
for (fname, elty) in ((:dsbmv_,:Float64),
(:ssbmv_,:Float32),
Expand Down
88 changes: 48 additions & 40 deletions base/linalg/symmetric.jl
Original file line number Diff line number Diff line change
@@ -1,63 +1,71 @@
#Symmetric and Hermitian matrices
immutable Symmetric{T} <: AbstractMatrix{T}
S::Matrix{T}
immutable Symmetric{T,S<:AbstractMatrix{T}} <: AbstractMatrix{T}
data::S
uplo::Char
end
Symmetric(A::Matrix, uplo::Symbol=:U) = (chksquare(A);Symmetric(A, string(uplo)[1]))
immutable Hermitian{T} <: AbstractMatrix{T}
S::Matrix{T}
Symmetric(A::AbstractMatrix, uplo::Symbol=:U) = (chksquare(A);Symmetric{eltype(A),typeof(A)}(A, string(uplo)[1]))
immutable Hermitian{T,S<:AbstractMatrix{T}} <: AbstractMatrix{T}
data::S
uplo::Char
end
Hermitian(A::Matrix, uplo::Symbol=:U) = (chksquare(A);Hermitian(A, string(uplo)[1]))
typealias HermOrSym{T} Union(Hermitian{T}, Symmetric{T})
typealias RealHermSymComplexHerm{T<:Real} Union(Hermitian{T}, Symmetric{T}, Hermitian{Complex{T}})
Hermitian(A::AbstractMatrix, uplo::Symbol=:U) = (chksquare(A);Hermitian{eltype(A),typeof(A)}(A, string(uplo)[1]))
typealias HermOrSym{T,S} Union(Hermitian{T,S}, Symmetric{T,S})
typealias RealHermSymComplexHerm{T<:Real,S} Union(Hermitian{T,S}, Symmetric{T,S}, Hermitian{Complex{T},S})

size(A::HermOrSym, args...) = size(A.S, args...)
getindex(A::Symmetric, i::Integer, j::Integer) = (A.uplo == 'U') == (i < j) ? getindex(A.S, i, j) : getindex(A.S, j, i)
getindex(A::Hermitian, i::Integer, j::Integer) = (A.uplo == 'U') == (i < j) ? getindex(A.S, i, j) : conj(getindex(A.S, j, i))
full(A::Symmetric) = copytri!(copy(A.S), A.uplo)
full(A::Hermitian) = copytri!(copy(A.S), A.uplo, true)
convert{T}(::Type{Symmetric{T}},A::Symmetric) = Symmetric(convert(Matrix{T},A.S), A.uplo)
convert{T}(::Type{AbstractMatrix{T}}, A::Symmetric) = Symmetric(convert(AbstractMatrix{T}, A.S), A.uplo)
convert{T}(::Type{Hermitian{T}},A::Hermitian) = Hermitian(convert(Matrix{T},A.S), A.uplo)
convert{T}(::Type{AbstractMatrix{T}}, A::Hermitian) = Hermitian(convert(AbstractMatrix{T}, A.S), A.uplo)
copy(A::Symmetric) = Symmetric(copy(A.S),A.uplo)
copy(A::Hermitian) = Hermitian(copy(A.S),A.uplo)
size(A::HermOrSym, args...) = size(A.data, args...)
getindex(A::Symmetric, i::Integer, j::Integer) = (A.uplo == 'U') == (i < j) ? getindex(A.data, i, j) : getindex(A.data, j, i)
getindex(A::Hermitian, i::Integer, j::Integer) = (A.uplo == 'U') == (i < j) ? getindex(A.data, i, j) : conj(getindex(A.data, j, i))
full(A::Symmetric) = copytri!(copy(A.data), A.uplo)
full(A::Hermitian) = copytri!(copy(A.data), A.uplo, true)
convert{T,S}(::Type{Symmetric{T,S}},A::Symmetric{T,S}) = A
convert{T,S}(::Type{Symmetric{T,S}},A::Symmetric) = Symmetric{T,S}(convert(S,A.data),A.uplo)
convert{T}(::Type{AbstractMatrix{T}}, A::Symmetric) = Symmetric(convert(AbstractMatrix{T}, A.data), symbol(A.uplo))
convert{T,S}(::Type{Hermitian{T,S}},A::Hermitian{T,S}) = A
convert{T,S}(::Type{Hermitian{T,S}},A::Hermitian) = Hermitian{T,S}(convert(S,A.data),A.uplo)
convert{T}(::Type{AbstractMatrix{T}}, A::Hermitian) = Hermitian(convert(AbstractMatrix{T}, A.data), symbol(A.uplo))
copy{T,S}(A::Symmetric{T,S}) = Symmetric{T,S}(copy(A.data),A.uplo)
copy{T,S}(A::Hermitian{T,S}) = Hermitian{T,S}(copy(A.data),A.uplo)
ishermitian(A::Hermitian) = true
ishermitian{T<:Real}(A::Symmetric{T}) = true
ishermitian{T<:Complex}(A::Symmetric{T}) = all(imag(A.S) .== 0)
issym{T<:Real}(A::Hermitian{T}) = true
issym{T<:Complex}(A::Hermitian{T}) = all(imag(A.S) .== 0)
ishermitian{T<:Real,S}(A::Symmetric{T,S}) = true
ishermitian{T<:Complex,S}(A::Symmetric{T,S}) = all(imag(A.data) .== 0)
issym{T<:Real,S}(A::Hermitian{T,S}) = true
issym{T<:Complex,S}(A::Hermitian{T,S}) = all(imag(A.data) .== 0)
issym(A::Symmetric) = true
transpose(A::Symmetric) = A
ctranspose(A::Hermitian) = A

## Matvec
A_mul_B!{T<:BlasFloat,S<:AbstractMatrix}(y::StridedVector{T}, A::Symmetric{T,S}, x::StridedVector{T}) = BLAS.symv!(A.uplo, one(T), A.data, x, zero(T), y)
A_mul_B!{T<:BlasComplex,S<:AbstractMatrix}(y::StridedVector{T}, A::Hermitian{T,S}, x::StridedVector{T}) = BLAS.hemv!(A.uplo, one(T), A.data, x, zero(T), y)
##Matmat
A_mul_B!{T<:BlasFloat,S<:AbstractMatrix}(C::StridedMatrix{T}, A::Symmetric{T,S}, B::StridedMatrix{T}) = BLAS.symm!(A.uplo, one(T), A.data, B, zero(T), C)
A_mul_B!{T<:BlasComplex,S<:AbstractMatrix}(y::StridedMatrix{T}, A::Hermitian{T,S}, x::StridedMatrix{T}) = BLAS.hemm!(A.uplo, one(T), A.data, B, zero(T), C)

*(A::HermOrSym, B::HermOrSym) = full(A)*full(B)
*(A::HermOrSym, B::StridedMatrix) = full(A)*B
*(A::StridedMatrix, B::HermOrSym) = A*full(B)

factorize(A::HermOrSym) = bkfact(A.S, symbol(A.uplo), issym(A))
\(A::HermOrSym, B::StridedVecOrMat) = \(bkfact(A.S, symbol(A.uplo), issym(A)), B)
factorize(A::HermOrSym) = bkfact(A.data, symbol(A.uplo), issym(A))
\(A::HermOrSym, B::StridedVecOrMat) = \(bkfact(A.data, symbol(A.uplo), issym(A)), B)

eigfact!{T<:BlasReal}(A::RealHermSymComplexHerm{T}) = Eigen(LAPACK.syevr!('V', 'A', A.uplo, A.S, 0.0, 0.0, 0, 0, -1.0)...)
eigfact!{T<:BlasReal}(A::RealHermSymComplexHerm{T}, irange::UnitRange) = Eigen(LAPACK.syevr!('V', 'I', A.uplo, A.S, 0.0, 0.0, irange.start, irange.stop, -1.0)...)
eigfact!{T<:BlasReal}(A::RealHermSymComplexHerm{T}, vl::Real, vh::Real) = Eigen(LAPACK.syevr!('V', 'V', A.uplo, A.S, convert(T, vl), convert(T, vh), 0, 0, -1.0)...)
eigvals!{T<:BlasReal}(A::RealHermSymComplexHerm{T}) = LAPACK.syevr!('N', 'A', A.uplo, A.S, 0.0, 0.0, 0, 0, -1.0)[1]
eigvals!{T<:BlasReal}(A::RealHermSymComplexHerm{T}, irange::UnitRange) = LAPACK.syevr!('N', 'I', A.uplo, A.S, 0.0, 0.0, irange.start, irange.stop, -1.0)[1]
eigvals!{T<:BlasReal}(A::RealHermSymComplexHerm{T}, vl::Real, vh::Real) = LAPACK.syevr!('N', 'V', A.uplo, A.S, convert(T, vl), convert(T, vh), 0, 0, -1.0)[1]
eigmax{T<:Real}(A::RealHermSymComplexHerm{T}) = eigvals(A, size(A, 1):size(A, 1))[1]
eigmin{T<:Real}(A::RealHermSymComplexHerm{T}) = eigvals(A, 1:1)[1]
eigfact!{T<:BlasReal,S<:StridedMatrix}(A::RealHermSymComplexHerm{T,S}) = Eigen(LAPACK.syevr!('V', 'A', A.uplo, A.data, 0.0, 0.0, 0, 0, -1.0)...)
eigfact!{T<:BlasReal,S<:StridedMatrix}(A::RealHermSymComplexHerm{T,S}, irange::UnitRange) = Eigen(LAPACK.syevr!('V', 'I', A.uplo, A.data, 0.0, 0.0, irange.start, irange.stop, -1.0)...)
eigfact!{T<:BlasReal,S<:StridedMatrix}(A::RealHermSymComplexHerm{T,S}, vl::Real, vh::Real) = Eigen(LAPACK.syevr!('V', 'V', A.uplo, A.data, convert(T, vl), convert(T, vh), 0, 0, -1.0)...)
eigvals!{T<:BlasReal,S<:StridedMatrix}(A::RealHermSymComplexHerm{T,S}) = LAPACK.syevr!('N', 'A', A.uplo, A.data, 0.0, 0.0, 0, 0, -1.0)[1]
eigvals!{T<:BlasReal,S<:StridedMatrix}(A::RealHermSymComplexHerm{T,S}, irange::UnitRange) = LAPACK.syevr!('N', 'I', A.uplo, A.data, 0.0, 0.0, irange.start, irange.stop, -1.0)[1]
eigvals!{T<:BlasReal,S<:StridedMatrix}(A::RealHermSymComplexHerm{T,S}, vl::Real, vh::Real) = LAPACK.syevr!('N', 'V', A.uplo, A.data, convert(T, vl), convert(T, vh), 0, 0, -1.0)[1]
eigmax{T<:Real,S<:StridedMatrix}(A::RealHermSymComplexHerm{T,S}) = eigvals(A, size(A, 1):size(A, 1))[1]
eigmin{T<:Real,S<:StridedMatrix}(A::RealHermSymComplexHerm{T,S}) = eigvals(A, 1:1)[1]

function eigfact!{T<:BlasReal}(A::HermOrSym{T}, B::HermOrSym{T})
vals, vecs, _ = LAPACK.sygvd!(1, 'V', A.uplo, A.S, B.uplo == A.uplo ? B.S : B.S')
function eigfact!{T<:BlasReal,S<:StridedMatrix}(A::HermOrSym{T,S}, B::HermOrSym{T,S})
vals, vecs, _ = LAPACK.sygvd!(1, 'V', A.uplo, A.data, B.uplo == A.uplo ? B.data : B.data')
GeneralizedEigen(vals, vecs)
end
function eigfact!{T<:BlasComplex}(A::Hermitian{T}, B::Hermitian{T})
vals, vecs, _ = LAPACK.sygvd!(1, 'V', A.uplo, A.S, B.uplo == A.uplo ? B.S : B.S')
function eigfact!{T<:BlasComplex,S<:StridedMatrix}(A::Hermitian{T,S}, B::Hermitian{T,S})
vals, vecs, _ = LAPACK.sygvd!(1, 'V', A.uplo, A.data, B.uplo == A.uplo ? B.data : B.data')
GeneralizedEigen(vals, vecs)
end
eigvals!{T<:BlasReal}(A::HermOrSym{T}, B::HermOrSym{T}) = LAPACK.sygvd!(1, 'N', A.uplo, A.S, B.uplo == A.uplo ? B.S : B.S')[1]
eigvals!{T<:BlasComplex}(A::Hermitian{T}, B::Hermitian{T}) = LAPACK.sygvd!(1, 'N', A.uplo, A.S, B.uplo == A.uplo ? B.S : B.S')[1]
eigvals!{T<:BlasReal,S<:StridedMatrix}(A::HermOrSym{T,S}, B::HermOrSym{T,S}) = LAPACK.sygvd!(1, 'N', A.uplo, A.data, B.uplo == A.uplo ? B.data : B.data')[1]
eigvals!{T<:BlasComplex,S<:StridedMatrix}(A::Hermitian{T,S}, B::Hermitian{T,S}) = LAPACK.sygvd!(1, 'N', A.uplo, A.data, B.uplo == A.uplo ? B.data : B.data')[1]

#Matrix-valued functions
expm{T<:Real}(A::RealHermSymComplexHerm{T}) = (F = eigfact(A); F.vectors*Diagonal(exp(F.values))*F.vectors')
Expand Down

0 comments on commit eb5b91c

Please sign in to comment.