Skip to content

Commit

Permalink
Generalize Multinomial moment to arbitrary dimensions
Browse files Browse the repository at this point in the history
  • Loading branch information
markvrma authored and ricardoV94 committed Mar 18, 2022
1 parent 5a44793 commit a0cff37
Show file tree
Hide file tree
Showing 2 changed files with 20 additions and 31 deletions.
25 changes: 4 additions & 21 deletions pymc/distributions/multivariate.py
Original file line number Diff line number Diff line change
Expand Up @@ -56,7 +56,7 @@
logpow,
multigammaln,
)
from pymc.distributions.distribution import Continuous, Discrete
from pymc.distributions.distribution import Continuous, Discrete, get_moment
from pymc.distributions.shape_utils import (
broadcast_dist_samples_to,
rv_size_is_none,
Expand Down Expand Up @@ -558,11 +558,7 @@ def dist(cls, n, p, *args, **kwargs):
return super().dist([n, p], *args, **kwargs)

def get_moment(rv, size, n, p):
if p.ndim > 1:
n = at.shape_padright(n)
if (p.ndim == 1) & (n.ndim > 0):
n = at.shape_padright(n)
p = at.shape_padleft(p)
n = at.shape_padright(n)
mode = at.round(n * p)
diff = n - at.sum(mode, axis=-1, keepdims=True)
inc_bool_arr = at.abs_(diff) > 0
Expand Down Expand Up @@ -682,21 +678,8 @@ def dist(cls, n, a, *args, **kwargs):
return super().dist([n, a], **kwargs)

def get_moment(rv, size, n, a):
p = a / at.sum(a, axis=-1)
mode = at.round(n * p)
diff = n - at.sum(mode, axis=-1, keepdims=True)
inc_bool_arr = at.abs_(diff) > 0
mode = at.inc_subtensor(mode[inc_bool_arr.nonzero()], diff[inc_bool_arr.nonzero()])

# Reshape mode according to dimensions implied by the parameters
# This can include axes of length 1
_, p_bcast = broadcast_params([n, p], ndims_params=[0, 1])
mode = at.reshape(mode, p_bcast.shape)

if not rv_size_is_none(size):
output_size = at.concatenate([size, [p.shape[-1]]])
mode = at.full(output_size, mode)
return mode
p = a / at.sum(a, axis=-1, keepdims=True)
return get_moment(Multinomial.dist(n=n, p=p, size=size))

def logp(value, n, a):
"""
Expand Down
26 changes: 16 additions & 10 deletions pymc/tests/test_distributions_moments.py
Original file line number Diff line number Diff line change
Expand Up @@ -1308,22 +1308,22 @@ def test_polyagamma_moment(h, z, size, expected):
np.array([[4, 6, 0, 0], [4, 2, 2, 2]]),
),
(
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
np.array([1, 10]),
None,
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
np.array([0.3, 0.6, 0.05, 0.05]),
np.array([2, 10]),
(1, 2),
np.array([[[1, 1, 0, 0], [4, 6, 0, 0]]]),
),
(
np.array([0.26, 0.26, 0.26, 0.22]),
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
np.array([1, 10]),
None,
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
),
(
np.array([[0.25, 0.25, 0.25, 0.25], [0.26, 0.26, 0.26, 0.22]]),
np.array([1, 10]),
(2, 2),
np.full((2, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
(3, 2),
np.full((3, 2, 4), [[1, 0, 0, 0], [2, 3, 3, 2]]),
),
],
)
Expand Down Expand Up @@ -1470,10 +1470,16 @@ def test_lkjcholeskycov_moment(n, eta, size, expected):
(np.array([3, 6, 0.5, 0.5]), 2, None, np.array([1, 1, 0, 0])),
(np.array([30, 60, 5, 5]), 10, None, np.array([4, 6, 0, 0])),
(
np.array([[26, 26, 26, 22]]), # Dim: 1 x 4
np.array([[1], [10]]), # Dim: 2 x 1
np.array([[30, 60, 5, 5], [26, 26, 26, 22]]),
10,
(1, 2),
np.array([[[4, 6, 0, 0], [2, 3, 3, 2]]]),
),
(
np.array([26, 26, 26, 22]),
np.array([1, 10]),
None,
np.array([[[1, 0, 0, 0]], [[2, 3, 3, 2]]]), # Dim: 2 x 1 x 4
np.array([[1, 0, 0, 0], [2, 3, 3, 2]]),
),
(
np.array([[26, 26, 26, 22]]), # Dim: 1 x 4
Expand Down

0 comments on commit a0cff37

Please sign in to comment.