From debbf4fcd972843ae3b74c38fc9fef87fa7fc116 Mon Sep 17 00:00:00 2001 From: Guillaume Dalle <22795598+gdalle@users.noreply.github.com> Date: Mon, 30 May 2022 23:00:31 +0200 Subject: [PATCH] Make suffstats and fit_mle type generic for Normal{T} --- src/univariate/continuous/normal.jl | 32 ++++++++++++++--------------- 1 file changed, 16 insertions(+), 16 deletions(-) diff --git a/src/univariate/continuous/normal.jl b/src/univariate/continuous/normal.jl index a4007c1e3..34261f22c 100644 --- a/src/univariate/continuous/normal.jl +++ b/src/univariate/continuous/normal.jl @@ -105,14 +105,14 @@ rand(rng::AbstractRNG, d::Normal{T}) where {T} = d.μ + d.σ * randn(rng, float( #### Fitting -struct NormalStats <: SufficientStats - s::Float64 # (weighted) sum of x - m::Float64 # (weighted) mean of x - s2::Float64 # (weighted) sum of (x - μ)^2 - tw::Float64 # total sample weight +struct NormalStats{T} <: SufficientStats + s::T # (weighted) sum of x + m::T # (weighted) mean of x + s2::T # (weighted) sum of (x - μ)^2 + tw::T # total sample weight end -function suffstats(::Type{<:Normal}, x::AbstractArray{T}) where T<:Real +function suffstats(::Type{Normal{T}}, x::AbstractArray{<:Real}) where {T<:Real} n = length(x) # compute s @@ -128,10 +128,10 @@ function suffstats(::Type{<:Normal}, x::AbstractArray{T}) where T<:Real @inbounds s2 += abs2(x[i] - m) end - NormalStats(s, m, s2, n) + NormalStats{T}(s, m, s2, n) end -function suffstats(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Real +function suffstats(::Type{Normal{T}}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real}) where {T<:Real} n = length(x) # compute s @@ -150,7 +150,7 @@ function suffstats(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float @inbounds s2 += w[i] * abs2(x[i] - m) end - NormalStats(s, m, s2, tw) + NormalStats{T}(s, m, s2, tw) end # Cases where μ or σ is known @@ -211,16 +211,16 @@ end # fit_mle based on sufficient statistics -fit_mle(::Type{<:Normal}, ss::NormalStats) = Normal(ss.m, sqrt(ss.s2 / ss.tw)) +fit_mle(::Type{D}, ss::NormalStats) where {D<:Normal} = D(ss.m, sqrt(ss.s2 / ss.tw)) fit_mle(g::NormalKnownMu, ss::NormalKnownMuStats) = Normal(g.μ, sqrt(ss.s2 / ss.tw)) fit_mle(g::NormalKnownSigma, ss::NormalKnownSigmaStats) = Normal(ss.sx / ss.tw, g.σ) # generic fit_mle methods -function fit_mle(::Type{<:Normal}, x::AbstractArray{T}; mu::Float64=NaN, sigma::Float64=NaN) where T<:Real +function fit_mle(::Type{D}, x::AbstractArray{<:Real}; mu::Float64=NaN, sigma::Float64=NaN) where {D<:Normal} if isnan(mu) if isnan(sigma) - fit_mle(Normal, suffstats(Normal, x)) + fit_mle(D, suffstats(Normal, x)) else g = NormalKnownSigma(sigma) fit_mle(g, suffstats(g, x)) @@ -230,15 +230,15 @@ function fit_mle(::Type{<:Normal}, x::AbstractArray{T}; mu::Float64=NaN, sigma:: g = NormalKnownMu(mu) fit_mle(g, suffstats(g, x)) else - Normal(mu, sigma) + D(mu, sigma) end end end -function fit_mle(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float64}; mu::Float64=NaN, sigma::Float64=NaN) where T<:Real +function fit_mle(::Type{D}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real}; mu::Float64=NaN, sigma::Float64=NaN) where {D<:Normal} if isnan(mu) if isnan(sigma) - fit_mle(Normal, suffstats(Normal, x, w)) + fit_mle(D, suffstats(Normal, x, w)) else g = NormalKnownSigma(sigma) fit_mle(g, suffstats(g, x, w)) @@ -248,7 +248,7 @@ function fit_mle(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float64 g = NormalKnownMu(mu) fit_mle(g, suffstats(g, x, w)) else - Normal(mu, sigma) + D(mu, sigma) end end end