From a1df819aa556a2bb269a5f24ed9138807f27d898 Mon Sep 17 00:00:00 2001 From: Tim Reichelt Date: Fri, 15 Nov 2019 16:03:19 +0000 Subject: [PATCH] Fix MvNormal (#995) (#1013) --- src/multivariate/mvnormal.jl | 2 ++ test/mvnormal.jl | 2 ++ 2 files changed, 4 insertions(+) diff --git a/src/multivariate/mvnormal.jl b/src/multivariate/mvnormal.jl index f6e98c5b9..830ddf64a 100644 --- a/src/multivariate/mvnormal.jl +++ b/src/multivariate/mvnormal.jl @@ -216,6 +216,8 @@ function MvNormal(μ::AbstractVector{<:Real}, σ::UniformScaling{<:Real}) MvNormal(convert(AbstractArray{R}, μ), R(σ.λ)) end MvNormal(Σ::Matrix{<:Real}) = MvNormal(PDMat(Σ)) +MvNormal(Σ::Union{Symmetric{<:Real}, Hermitian{<:Real}}) = MvNormal(PDMat(Σ)) +MvNormal(Σ::Diagonal{<:Real}) = MvNormal(PDiagMat(diag(Σ))) MvNormal(σ::Vector{<:Real}) = MvNormal(PDiagMat(abs2.(σ))) MvNormal(d::Int, σ::Real) = MvNormal(ScalMat(d, abs2(σ))) diff --git a/test/mvnormal.jl b/test/mvnormal.jl index a933b30ff..5ede92781 100644 --- a/test/mvnormal.jl +++ b/test/mvnormal.jl @@ -106,6 +106,8 @@ end (MvNormal(mu, C), mu, C), (MvNormal(mu_r, C), mu_r, C), (MvNormal(C), zeros(3), C), + (MvNormal(Symmetric(C)), zeros(3), Matrix(Symmetric(C))), + (MvNormal(Diagonal(dv)), zeros(3), Matrix(Diagonal(dv))), (MvNormalCanon(h, 2.0), h ./ 2.0, Matrix(0.5I, 3, 3)), (MvNormalCanon(mu_r, 2.0), mu_r ./ 2.0, Matrix(0.5I, 3, 3)), (MvNormalCanon(3, 2.0), zeros(3), Matrix(0.5I, 3, 3)),