Skip to content

Commit

Permalink
Make suffstats and fit_mle type generic for Normal{T}
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed May 30, 2022
1 parent f889f9e commit debbf4f
Showing 1 changed file with 16 additions and 16 deletions.
32 changes: 16 additions & 16 deletions src/univariate/continuous/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -105,14 +105,14 @@ rand(rng::AbstractRNG, d::Normal{T}) where {T} = d.μ + d.σ * randn(rng, float(

#### Fitting

struct NormalStats <: SufficientStats
s::Float64 # (weighted) sum of x
m::Float64 # (weighted) mean of x
s2::Float64 # (weighted) sum of (x - μ)^2
tw::Float64 # total sample weight
struct NormalStats{T} <: SufficientStats
s::T # (weighted) sum of x
m::T # (weighted) mean of x
s2::T # (weighted) sum of (x - μ)^2
tw::T # total sample weight
end

function suffstats(::Type{<:Normal}, x::AbstractArray{T}) where T<:Real
function suffstats(::Type{Normal{T}}, x::AbstractArray{<:Real}) where {T<:Real}
n = length(x)

# compute s
Expand All @@ -128,10 +128,10 @@ function suffstats(::Type{<:Normal}, x::AbstractArray{T}) where T<:Real
@inbounds s2 += abs2(x[i] - m)
end

NormalStats(s, m, s2, n)
NormalStats{T}(s, m, s2, n)
end

function suffstats(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Real
function suffstats(::Type{Normal{T}}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real}) where {T<:Real}
n = length(x)

# compute s
Expand All @@ -150,7 +150,7 @@ function suffstats(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float
@inbounds s2 += w[i] * abs2(x[i] - m)
end

NormalStats(s, m, s2, tw)
NormalStats{T}(s, m, s2, tw)
end

# Cases where μ or σ is known
Expand Down Expand Up @@ -211,16 +211,16 @@ end

# fit_mle based on sufficient statistics

fit_mle(::Type{<:Normal}, ss::NormalStats) = Normal(ss.m, sqrt(ss.s2 / ss.tw))
fit_mle(::Type{D}, ss::NormalStats) where {D<:Normal} = D(ss.m, sqrt(ss.s2 / ss.tw))
fit_mle(g::NormalKnownMu, ss::NormalKnownMuStats) = Normal(g.μ, sqrt(ss.s2 / ss.tw))
fit_mle(g::NormalKnownSigma, ss::NormalKnownSigmaStats) = Normal(ss.sx / ss.tw, g.σ)

# generic fit_mle methods

function fit_mle(::Type{<:Normal}, x::AbstractArray{T}; mu::Float64=NaN, sigma::Float64=NaN) where T<:Real
function fit_mle(::Type{D}, x::AbstractArray{<:Real}; mu::Float64=NaN, sigma::Float64=NaN) where {D<:Normal}
if isnan(mu)
if isnan(sigma)
fit_mle(Normal, suffstats(Normal, x))
fit_mle(D, suffstats(Normal, x))
else
g = NormalKnownSigma(sigma)
fit_mle(g, suffstats(g, x))
Expand All @@ -230,15 +230,15 @@ function fit_mle(::Type{<:Normal}, x::AbstractArray{T}; mu::Float64=NaN, sigma::
g = NormalKnownMu(mu)
fit_mle(g, suffstats(g, x))
else
Normal(mu, sigma)
D(mu, sigma)
end
end
end

function fit_mle(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float64}; mu::Float64=NaN, sigma::Float64=NaN) where T<:Real
function fit_mle(::Type{D}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real}; mu::Float64=NaN, sigma::Float64=NaN) where {D<:Normal}
if isnan(mu)
if isnan(sigma)
fit_mle(Normal, suffstats(Normal, x, w))
fit_mle(D, suffstats(Normal, x, w))
else
g = NormalKnownSigma(sigma)
fit_mle(g, suffstats(g, x, w))
Expand All @@ -248,7 +248,7 @@ function fit_mle(::Type{<:Normal}, x::AbstractArray{T}, w::AbstractArray{Float64
g = NormalKnownMu(mu)
fit_mle(g, suffstats(g, x, w))
else
Normal(mu, sigma)
D(mu, sigma)
end
end
end

0 comments on commit debbf4f

Please sign in to comment.