Skip to content

Commit

Permalink
Type generic partially known Normal
Browse files Browse the repository at this point in the history
  • Loading branch information
gdalle committed Jun 9, 2023
1 parent 7443d7b commit 4b27b30
Showing 1 changed file with 48 additions and 31 deletions.
79 changes: 48 additions & 31 deletions src/univariate/continuous/normal.jl
Original file line number Diff line number Diff line change
Expand Up @@ -125,9 +125,13 @@ struct NormalStats{T} <: SufficientStats
m::T # (weighted) mean of x
s2::T # (weighted) sum of (x - μ)^2
tw::T # total sample weight
function NormalStats(s::T1, m::T2, s2::T3, tw::T4) where {T1,T2,T3,T4}
T = promote_type(T1, T2, T3, T4)
return new{T}(T(s), T(m), T(s2), T(tw))
end
end

function suffstats(::Type{Normal{T}}, x::AbstractArray{<:Real}) where {T<:Real}
function suffstats(::Type{<:Normal}, x::AbstractArray{<:Real})
n = length(x)

# compute s
Expand All @@ -143,10 +147,10 @@ function suffstats(::Type{Normal{T}}, x::AbstractArray{<:Real}) where {T<:Real}
@inbounds s2 += abs2(x[i] - m)
end

NormalStats{T}(s, m, s2, n)
NormalStats(s, m, s2, n)
end

function suffstats(::Type{Normal{T}}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real}) where {T<:Real}
function suffstats(::Type{<:Normal}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
n = length(x)

# compute s
Expand All @@ -165,22 +169,26 @@ function suffstats(::Type{Normal{T}}, x::AbstractArray{<:Real}, w::AbstractArray
@inbounds s2 += w[i] * abs2(x[i] - m)
end

NormalStats{T}(s, m, s2, tw)
NormalStats(s, m, s2, tw)
end

# Cases where μ or σ is known

struct NormalKnownMu <: IncompleteDistribution
μ::Float64
struct NormalKnownMu{T} <: IncompleteDistribution
μ::T
end

struct NormalKnownMuStats <: SufficientStats
μ::Float64 # known mean
s2::Float64 # (weighted) sum of (x - μ)^2
tw::Float64 # total sample weight
struct NormalKnownMuStats{T} <: SufficientStats
μ::T # known mean
s2::T # (weighted) sum of (x - μ)^2
tw::T # total sample weight
function NormalKnownMuStats::T1, s2::T2, tw::T3) where {T1,T2,T3}
T = promote_type(T1, T2, T3)
return new{T}(μ, s2, tw)
end
end

function suffstats(g::NormalKnownMu, x::AbstractArray{T}) where T<:Real
function suffstats(g::NormalKnownMu, x::AbstractArray{<:Real})
μ = g.μ
s2 = abs2(x[1] - μ)
for i = 2:length(x)
Expand All @@ -189,7 +197,7 @@ function suffstats(g::NormalKnownMu, x::AbstractArray{T}) where T<:Real
NormalKnownMuStats(g.μ, s2, length(x))
end

function suffstats(g::NormalKnownMu, x::AbstractArray{T}, w::AbstractArray{Float64}) where T<:Real
function suffstats(g::NormalKnownMu, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
μ = g.μ
s2 = abs2(x[1] - μ) * w[1]
tw = w[1]
Expand All @@ -201,26 +209,29 @@ function suffstats(g::NormalKnownMu, x::AbstractArray{T}, w::AbstractArray{Float
NormalKnownMuStats(g.μ, s2, tw)
end

struct NormalKnownSigma <: IncompleteDistribution
σ::Float64

function NormalKnownSigma::Float64)
struct NormalKnownSigma{T} <: IncompleteDistribution
σ::T
function NormalKnownSigma::T) where {T}
σ > 0 || throw(ArgumentError("σ must be a positive value."))
new(σ)
return new{T}(σ)
end
end

struct NormalKnownSigmaStats <: SufficientStats
σ::Float64 # known std.dev
sx::Float64 # (weighted) sum of x
tw::Float64 # total sample weight
struct NormalKnownSigmaStats{T} <: SufficientStats
σ::T # known std.dev
sx::T # (weighted) sum of x
tw::T # total sample weight
function NormalKnownSigmaStats::T1, sx::T2, tw::T3) where {T1,T2,T3}
T = promote_type(T1, T2, T3)
return new{T}(σ, sx, tw)
end
end

function suffstats(g::NormalKnownSigma, x::AbstractArray{T}) where T<:Real
function suffstats(g::NormalKnownSigma, x::AbstractArray{<:Real})
NormalKnownSigmaStats(g.σ, sum(x), Float64(length(x)))
end

function suffstats(g::NormalKnownSigma, x::AbstractArray{T}, w::AbstractArray{T}) where T<:Real
function suffstats(g::NormalKnownSigma, x::AbstractArray{<:Real}, w::AbstractArray{<:Real})
NormalKnownSigmaStats(g.σ, dot(x, w), sum(w))
end

Expand All @@ -232,16 +243,19 @@ fit_mle(g::NormalKnownSigma, ss::NormalKnownSigmaStats) = Normal(ss.sx / ss.tw,

# generic fit_mle methods

function fit_mle(::Type{D}, x::AbstractArray{<:Real}; mu::Float64=NaN, sigma::Float64=NaN) where {D<:Normal}
if isnan(mu)
if isnan(sigma)
function fit_mle(
::Type{D}, x::AbstractArray{<:Real};
mu::Union{Nothing,<:Real}=nothing, sigma::Union{Nothing,<:Real}=nothing
) where {D<:Normal}
if isnothing(mu)
if isnothing(sigma)
fit_mle(D, suffstats(Normal, x))
else
g = NormalKnownSigma(sigma)
fit_mle(g, suffstats(g, x))
end
else
if isnan(sigma)
if isnothing(sigma)
g = NormalKnownMu(mu)
fit_mle(g, suffstats(g, x))
else
Expand All @@ -250,16 +264,19 @@ function fit_mle(::Type{D}, x::AbstractArray{<:Real}; mu::Float64=NaN, sigma::Fl
end
end

function fit_mle(::Type{D}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real}; mu::Float64=NaN, sigma::Float64=NaN) where {D<:Normal}
if isnan(mu)
if isnan(sigma)
function fit_mle(
::Type{D}, x::AbstractArray{<:Real}, w::AbstractArray{<:Real};
mu::Union{Nothing,<:Real}=nothing, sigma::Union{Nothing,<:Real}=nothing
) where {D<:Normal}
if isnothing(mu)
if isnothing(sigma)
fit_mle(D, suffstats(Normal, x, w))
else
g = NormalKnownSigma(sigma)
fit_mle(g, suffstats(g, x, w))
end
else
if isnan(sigma)
if isnothing(sigma)
g = NormalKnownMu(mu)
fit_mle(g, suffstats(g, x, w))
else
Expand Down

0 comments on commit 4b27b30

Please sign in to comment.