Add NamedTupleVariate and ProductNamedTupleDistribution #1803

Merged
merged 46 commits into from
Jan 17, 2025
Changes from 42 commits
142380b
Add NamedTupleVariate
sethaxen Nov 23, 2023
191ca1a
Add ProductNamedTupleDistribution
sethaxen Nov 23, 2023
399b03b
Correctly implement eltype
sethaxen Nov 23, 2023
eb946a8
Simplify insupport implementation
sethaxen Nov 23, 2023
32ca2f0
Overload std for ProductNamedTupleDistribution
sethaxen Nov 23, 2023
a416f02
Simplify rand for ProductNamedTupleDistribution
sethaxen Nov 23, 2023
7deff94
Reformat line
sethaxen Nov 23, 2023
978b2de
Add docstring to ProductNamedTupleDistribution
sethaxen Nov 23, 2023
b718e59
Add marginal API function
sethaxen Nov 23, 2023
d08431d
Add marginal for ProductDistribution
sethaxen Nov 23, 2023
79e5d59
Rearrange marginal
sethaxen Nov 23, 2023
52fb9a0
Allow tuple indexing via marginal
sethaxen Nov 23, 2023
1509abd
Make logpdf type-stable
sethaxen Nov 23, 2023
450fb7d
Add loglikelihood
sethaxen Nov 23, 2023
eb2ed6c
Support extrema for multivariate distributions
sethaxen Nov 23, 2023
e3a0814
Add tests
sethaxen Nov 23, 2023
9acc869
Improve type-inferrability
sethaxen Nov 23, 2023
6d8df2a
Remove extension
sethaxen Nov 23, 2023
d115441
Merge branch 'master' into namedtuplevariate
sethaxen May 27, 2024
9f19a2e
Merge branch 'master' into namedtuplevariate
devmotion Jul 14, 2024
800de5b
Apply suggestions from code review
sethaxen Jul 15, 2024
0b83587
Remove marginal
sethaxen Jul 15, 2024
ba03eea
Add sampler for product namedtuple
sethaxen Jul 15, 2024
1712be6
Use ProductNamedTupleSampler for array rand calls
sethaxen Jul 15, 2024
1056d0d
Add docs page for product distributions
sethaxen Aug 19, 2024
58937fd
Fix typo
sethaxen Aug 19, 2024
d7fd842
Fix ProductNamedTuple docstring
sethaxen Aug 19, 2024
c8b1602
Add deprecation warning to Product docstring
sethaxen Aug 19, 2024
db029c5
Move multivariate product distributions to own page
sethaxen Aug 19, 2024
2634adb
Document NamedTuple products
sethaxen Aug 19, 2024
eb5b176
Add docs index
sethaxen Aug 19, 2024
3ebc3ba
Document usage of ProductNamedTuple
sethaxen Aug 19, 2024
f0dd8c4
Load Distributions for jldoctest
sethaxen Aug 19, 2024
121dd2b
Apply suggestions from code review
sethaxen Sep 4, 2024
a86cac4
Call method on NamedTuple
sethaxen Sep 4, 2024
46fdcfc
Revert to typejoin based eltype
sethaxen Sep 5, 2024
54b0d03
Explicitly check eltype of dist matches that of draw
sethaxen Sep 5, 2024
fe284b1
Correctly compute eltype for nested prod namedtuple distributions
sethaxen Sep 5, 2024
96ccc99
Merge branch 'master' into namedtuplevariate
sethaxen Sep 5, 2024
1eabd23
Revert "Call method on NamedTuple"
sethaxen Sep 5, 2024
28a7c00
Update test/namedtuple/productnamedtuple.jl
sethaxen Sep 5, 2024
f7ab7c0
Merge branch 'master' into namedtuplevariate
sethaxen Jan 5, 2025
935b5b2
Merge branch 'master' into namedtuplevariate
sethaxen Jan 16, 2025
665ab52
Support permutations of NamedTuple fields
sethaxen Jan 16, 2025
8188d35
Fix formatting
sethaxen Jan 16, 2025
d33c31d
Support permutations of names in kldivergence
sethaxen Jan 16, 2025
1 change: 1 addition & 0 deletions docs/make.jl
@@ -17,6 +17,7 @@ makedocs(
"reshape.md",
"cholesky.md",
"mixture.md",
"product.md",
"order_statistics.md",
"convolution.md",
"fit.md",
10 changes: 0 additions & 10 deletions docs/src/multivariate.md
@@ -58,7 +58,6 @@ MvNormalCanon
MvLogitNormal
MvLogNormal
Dirichlet
Product
```

## Addition Methods
@@ -105,15 +104,6 @@ params{D<:Distributions.AbstractMvLogNormal}(::Type{D},m::AbstractVector,S::Abst
Distributions._logpdf(d::MultivariateDistribution, x::AbstractArray)
```

## Product distributions

```@docs
Distributions.product_distribution
```

Using `product_distribution` is advised to construct product distributions.
For some distributions, it constructs a special multivariate type.

## Index

```@index
27 changes: 27 additions & 0 deletions docs/src/product.md
@@ -0,0 +1,27 @@
# Product Distributions

Product distributions are joint distributions of multiple independent distributions.
It is recommended to use `product_distribution` to construct product distributions.
Depending on the type of the argument, it may construct a different distribution type.

## Multivariate products

```@docs
Distributions.product_distribution(::AbstractArray{<:Distribution{<:ArrayLikeVariate}})
Distributions.product_distribution(::AbstractVector{<:Normal})
Distributions.ProductDistribution
Distributions.Product
```

## NamedTuple-variate products

```@docs
Distributions.product_distribution(::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}})
Distributions.ProductNamedTupleDistribution
```

## Index

```@index
Pages = ["product.md"]
```
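
The new docs page says the return type of `product_distribution` depends on the type of its argument. A minimal sketch of that dispatch follows, assuming a Distributions.jl release that includes this PR. The variable names are illustrative only, and the type names are qualified with `Distributions.` because this diff does not show them being exported.

```julia
using Distributions

# Vector of Normals: the specialized method listed on the docs page stacks them
# into a multivariate normal with diagonal covariance.
d_normal = product_distribution([Normal(0.0, 1.0), Normal(1.0, 2.0)])

# Vector of mixed univariate distributions: the generic array-variate product.
d_array = product_distribution([Uniform(0.0, 1.0), Exponential(2.0)])

# NamedTuple of distributions: the NamedTuple-variate product added in this PR.
d_named = product_distribution((x=Normal(), y=Dirichlet([2, 4])))

for d in (d_normal, d_array, d_named)
    println(nameof(typeof(d)))
end
```

The third call is the only one whose variate form is `NamedTupleVariate`; the first two return array-variate distributions.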
2 changes: 2 additions & 0 deletions src/Distributions.jl
@@ -41,6 +41,7 @@ export
Multivariate,
Matrixvariate,
CholeskyVariate,
NamedTupleVariate,
Discrete,
Continuous,
Sampleable,
@@ -296,6 +297,7 @@ include("univariates.jl")
include("edgeworth.jl")
include("multivariates.jl")
include("matrixvariates.jl")
include("namedtuple/productnamedtuple.jl")
include("cholesky/lkjcholesky.jl")
include("samplers.jl")

6 changes: 6 additions & 0 deletions src/common.jl
@@ -16,6 +16,12 @@ const Univariate = ArrayLikeVariate{0}
const Multivariate = ArrayLikeVariate{1}
const Matrixvariate = ArrayLikeVariate{2}

"""
`F <: NamedTupleVariate{K}` specifies that the variate or a sample is of type
`NamedTuple{K}`.
"""
struct NamedTupleVariate{K} <: VariateForm end

"""
`F <: CholeskyVariate` specifies that the variate or a sample is of type
`LinearAlgebra.Cholesky`.
4 changes: 4 additions & 0 deletions src/multivariate/product.jl
@@ -10,6 +10,10 @@ An N dimensional `MultivariateDistribution` constructed from a vector of N indep
```julia
Product(Uniform.(rand(10), 1)) # A 10-dimensional Product from 10 independent `Uniform` distributions.
```

!!! note
    `Product` is deprecated and will be removed in the next breaking release.
    Use [`product_distribution`](@ref) instead.
"""
struct Product{
S<:ValueSupport,
155 changes: 155 additions & 0 deletions src/namedtuple/productnamedtuple.jl
@@ -0,0 +1,155 @@
"""
    ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
        Distribution{NamedTupleVariate{Tnames},S}

A distribution of `NamedTuple`s, constructed from a `NamedTuple` of independent named
distributions.

Users should use [`product_distribution`](@ref) to construct a product distribution of
independent distributions instead of constructing a `ProductNamedTupleDistribution`
directly.

# Examples

```jldoctest ProductNamedTuple; setup = :(using Distributions, Random; Random.seed!(832))
julia> d = product_distribution((x=Normal(), y=Dirichlet([2, 4])))
ProductNamedTupleDistribution{(:x, :y)}(
x: Normal{Float64}(μ=0.0, σ=1.0)
y: Dirichlet{Int64, Vector{Int64}, Float64}(alpha=[2, 4])
)


julia> nt = rand(d)
(x = 1.5155385995160346, y = [0.533531876438439, 0.466468123561561])

julia> pdf(d, nt)
0.13702825691074877

julia> mode(d) # mode of marginals
(x = 0.0, y = [0.25, 0.75])

julia> mean(d) # mean of marginals
(x = 0.0, y = [0.3333333333333333, 0.6666666666666666])

julia> var(d) # var of marginals
(x = 1.0, y = [0.031746031746031744, 0.031746031746031744])
```
"""
struct ProductNamedTupleDistribution{Tnames,Tdists,S<:ValueSupport,eltypes} <:
       Distribution{NamedTupleVariate{Tnames},S}
    dists::NamedTuple{Tnames,Tdists}
end
function ProductNamedTupleDistribution(
    dists::NamedTuple{K,V}
) where {K,V<:Tuple{Distribution,Vararg{Distribution}}}
    vs = _product_valuesupport(values(dists))
    eltypes = _product_namedtuple_eltype(values(dists))
    return ProductNamedTupleDistribution{K,V,vs,eltypes}(dists)
end

_gentype(d::UnivariateDistribution) = eltype(d)
_gentype(d::Distribution{<:ArrayLikeVariate{S}}) where {S} = Array{eltype(d),S}
function _gentype(d::Distribution{CholeskyVariate})
    T = eltype(d)
    return LinearAlgebra.Cholesky{T,Matrix{T}}
end
function _gentype(d::ProductNamedTupleDistribution{K}) where {K}
    return NamedTuple{K,Tuple{map(_gentype, values(d.dists))...}}
end
_gentype(::Distribution) = Any

_product_namedtuple_eltype(dists) = typejoin(map(_gentype, dists)...)

function Base.show(io::IO, d::ProductNamedTupleDistribution)
    return show_multline(io, d, collect(pairs(d.dists)))
end

function distrname(::ProductNamedTupleDistribution{K}) where {K}
    return "ProductNamedTupleDistribution{$K}"
end

"""
    product_distribution(dists::NamedTuple{K,Tuple{Vararg{Distribution}}}) where {K}

Create a distribution of `NamedTuple`s as a product distribution of independent named
distributions.

The function falls back to constructing a [`ProductNamedTupleDistribution`](@ref)
distribution but specialized methods can be defined.
"""
function product_distribution(
    dists::NamedTuple{<:Any,<:Tuple{Distribution,Vararg{Distribution}}}
)
    return ProductNamedTupleDistribution(dists)
end

# Properties

Base.eltype(::Type{<:ProductNamedTupleDistribution{<:Any,<:Any,<:Any,T}}) where {T} = T

Base.minimum(d::ProductNamedTupleDistribution) = map(minimum, d.dists)

Base.maximum(d::ProductNamedTupleDistribution) = map(maximum, d.dists)

function insupport(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K}
    return all(map(insupport, dist.dists, x))
end

# Evaluation

function pdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K}
    return exp(logpdf(dist, x))
end

function logpdf(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K}
    return sum(map(logpdf, dist.dists, x))
end

function loglikelihood(dist::ProductNamedTupleDistribution{K}, x::NamedTuple{K}) where {K}
    return logpdf(dist, x)
end

function loglikelihood(
    dist::ProductNamedTupleDistribution{K}, xs::AbstractArray{<:NamedTuple{K}}
) where {K}
    return sum(Base.Fix1(loglikelihood, dist), xs)
end

# Statistics

mode(d::ProductNamedTupleDistribution) = map(mode, d.dists)

mean(d::ProductNamedTupleDistribution) = map(mean, d.dists)

var(d::ProductNamedTupleDistribution) = map(var, d.dists)

std(d::ProductNamedTupleDistribution) = map(std, d.dists)

entropy(d::ProductNamedTupleDistribution) = sum(entropy, values(d.dists))

function kldivergence(
    d1::ProductNamedTupleDistribution{K}, d2::ProductNamedTupleDistribution{K}
) where {K}
    return sum(map(kldivergence, d1.dists, d2.dists))
end

# Sampling

function sampler(d::ProductNamedTupleDistribution{K,<:Any,S}) where {K,S}
    samplers = map(sampler, d.dists)
    Tsamplers = typeof(values(samplers))
    return ProductNamedTupleSampler{K,Tsamplers,S}(samplers)
end

function Base.rand(rng::AbstractRNG, d::ProductNamedTupleDistribution{K}) where {K}
    return NamedTuple{K}(map(Base.Fix1(rand, rng), d.dists))
end
function Base.rand(
    rng::AbstractRNG, d::ProductNamedTupleDistribution{K}, dims::Dims
) where {K}
    return convert(AbstractArray{<:NamedTuple{K}}, _rand(rng, sampler(d), dims))
end

function _rand!(rng::AbstractRNG, d::ProductNamedTupleDistribution, xs::AbstractArray)
    return _rand!(rng, sampler(d), xs)
end
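
The evaluation methods in this file define `logpdf` of the product as the sum of the component `logpdf`s, and the statistics as field-wise maps over `d.dists`. A minimal sketch checking both properties follows, assuming a Distributions.jl release that includes this PR. It reuses the distribution from the docstring example; the RNG seed and variable names are arbitrary choices for illustration.

```julia
using Distributions, Random

d = product_distribution((x=Normal(), y=Dirichlet([2, 4])))
rng = Random.MersenneTwister(1)  # arbitrary seed, only for reproducibility

# A single draw is a NamedTuple with the same field names as the product.
nt = rand(rng, d)

# Joint log-density versus the sum of the independent components' log-densities.
lp = logpdf(d, nt)
lp_manual = logpdf(Normal(), nt.x) + logpdf(Dirichlet([2, 4]), nt.y)

# Field-wise statistics: each field holds the statistic of that marginal.
m = mean(d)

println(keys(nt))
println(isapprox(lp, lp_manual))
println(m.x == mean(Normal()))
```

Because the components are independent, field-wise `mean`, `var`, and `mode` describe the marginals rather than any joint structure.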
4 changes: 3 additions & 1 deletion src/samplers.jl
@@ -24,7 +24,9 @@ for fname in ["aliastable.jl",
"vonmises.jl",
"vonmisesfisher.jl",
"discretenonparametric.jl",
"categorical.jl"]
"categorical.jl",
"productnamedtuple.jl",
]

include(joinpath("samplers", fname))
end
21 changes: 21 additions & 0 deletions src/samplers/productnamedtuple.jl
@@ -0,0 +1,21 @@
struct ProductNamedTupleSampler{Tnames,Tsamplers,S<:ValueSupport} <:
       Sampleable{NamedTupleVariate{Tnames},S}
    samplers::NamedTuple{Tnames,Tsamplers}
end

function Base.rand(rng::AbstractRNG, spl::ProductNamedTupleSampler{K}) where {K}
    return NamedTuple{K}(map(Base.Fix1(rand, rng), spl.samplers))
end

function _rand(rng::AbstractRNG, spl::ProductNamedTupleSampler, dims::Dims)
    return map(CartesianIndices(dims)) do _
        return rand(rng, spl)
    end
end

function _rand!(rng::AbstractRNG, spl::ProductNamedTupleSampler, xs::AbstractArray)
    for i in eachindex(xs)
        xs[i] = rand(rng, spl)
    end
    return xs
end
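
Array-shaped `rand` calls on the distribution go through this sampler, so each element of the result is an independent `NamedTuple` draw. A minimal sketch of that path follows, assuming a Distributions.jl release that includes this PR. The distribution, seed, and variable names are illustrative, and `ProductNamedTupleSampler` is referenced with the `Distributions.` prefix because the diff does not show it being exported.

```julia
using Distributions, Random

rng = Random.MersenneTwister(42)  # arbitrary seed, only for reproducibility
d = product_distribution((a=Normal(), b=Exponential(2.0)))

# Requesting dimensions returns an array whose elements are NamedTuple draws.
xs = rand(rng, d, 5)

# loglikelihood over an array of draws is the sum of the per-draw logpdfs.
ll = loglikelihood(d, xs)
ll_manual = sum(x -> logpdf(d, x), xs)

# The sampler can also be built and used directly.
spl = Distributions.sampler(d)
draw = rand(rng, spl)

println(size(xs))
println(isapprox(ll, ll_manual))
println(keys(draw))
```

Building the sampler once and reusing it for every element is what lets components with expensive setup amortize that cost across the array.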