Remove Dirichlet distribution type restrictions (#4000)
* Remove Dirichlet distribution type restrictions

Closes #3999.

* Add missing Dirichlet shape parameters to tests

* Remove Dirichlet positive concentration parameter constructor tests

This test can't be performed in the constructor if we're allowing Theano-type
distribution parameters.

* Add a hack to statically infer Dirichlet argument shapes

Co-authored-by: Brandon T. Willard <[email protected]>
brandonwillard authored Jul 21, 2020
1 parent b2c682e commit f07c273
Showing 5 changed files with 41 additions and 43 deletions.
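
For context, the removed constructor check only accepted a Python list or NumPy array for the concentration parameter `a`. A minimal sketch of what this change enables (the model below is illustrative, not part of the diff): `a` may now be any tensor-like value, including a Theano expression built from other random variables.

import numpy as np
import pymc3 as pm
import theano.tensor as tt

with pm.Model():
    # A Theano expression as the concentration parameter; the removed
    # isinstance() check would previously have raised a TypeError here.
    a = tt.exp(pm.Normal("log_a", 0.0, 1.0, shape=3))
    w = pm.Dirichlet("w", a=a, shape=3)
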
29 changes: 15 additions & 14 deletions pymc3/distributions/multivariate.py
@@ -23,6 +23,7 @@
 
 from scipy import stats, linalg
 
+from theano.gof.op import get_test_value
 from theano.tensor.nlinalg import det, matrix_inverse, trace, eigh
 from theano.tensor.slinalg import Cholesky
 import pymc3 as pm
@@ -487,22 +488,23 @@ class Dirichlet(Continuous):
     def __init__(self, a, transform=transforms.stick_breaking,
                  *args, **kwargs):
 
-        if not isinstance(a, pm.model.TensorVariable):
-            if not isinstance(a, list) and not isinstance(a, np.ndarray):
-                raise TypeError(
-                    'The vector of concentration parameters (a) must be a python list '
-                    'or numpy array.')
-            a = np.array(a)
-            if (a <= 0).any():
-                raise ValueError("All concentration parameters (a) must be > 0.")
-
-        shape = np.atleast_1d(a.shape)[-1]
-
-        kwargs.setdefault("shape", shape)
+        if kwargs.get('shape') is None:
+            warnings.warn(
+                (
+                    "Shape not explicitly set. "
+                    "Please, set the value using the `shape` keyword argument. "
+                    "Using the test value to infer the shape."
+                ),
+                DeprecationWarning
+            )
+            try:
+                kwargs['shape'] = get_test_value(tt.shape(a))
+            except AttributeError:
+                pass
+
         super().__init__(transform=transform, *args, **kwargs)
 
         self.size_prefix = tuple(self.shape[:-1])
-        self.k = tt.as_tensor_variable(shape)
         self.a = a = tt.as_tensor_variable(a)
         self.mean = a / tt.sum(a)
@@ -569,14 +571,13 @@ def logp(self, value):
         -------
         TensorVariable
         """
-        k = self.k
         a = self.a
 
         # only defined for sum(value) == 1
         return bound(tt.sum(logpow(value, a - 1) - gammaln(a), axis=-1)
                      + gammaln(tt.sum(a, axis=-1)),
                      tt.all(value >= 0), tt.all(value <= 1),
-                     k > 1, tt.all(a > 0),
+                     np.logical_not(a.broadcastable), tt.all(a > 0),
                      broadcast_conditions=False)
 
     def _repr_latex_(self, name=None, dist=None):
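
A sketch of the new fallback behavior (assuming a build that includes this commit): when `shape` is omitted, the constructor warns and falls back to the test value of `tt.shape(a)`, which exists for constant inputs. In `logp`, the `k > 1` condition is replaced by `np.logical_not(a.broadcastable)` because `self.k` no longer exists.

import numpy as np
import pymc3 as pm

# Omitting `shape` now emits a DeprecationWarning, and the shape is
# statically inferred from the test value of tt.shape(a).
dir_rv = pm.Dirichlet.dist(np.r_[1.0, 2.0, 3.0])
assert dir_rv.shape == (3,)
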
4 changes: 2 additions & 2 deletions pymc3/tests/test_dist_math.py
@@ -126,11 +126,11 @@ def test_multinomial_bound():
     n = x.sum()
 
     with pm.Model() as modelA:
-        p_a = pm.Dirichlet('p', floatX(np.ones(2)))
+        p_a = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
         MultinomialA('x', n, p_a, observed=x)
 
     with pm.Model() as modelB:
-        p_b = pm.Dirichlet('p', floatX(np.ones(2)))
+        p_b = pm.Dirichlet('p', floatX(np.ones(2)), shape=(2,))
         MultinomialB('x', n, p_b, observed=x)
 
     assert np.isclose(modelA.logp({'p_stickbreaking__': [0]}),
19 changes: 8 additions & 11 deletions pymc3/tests/test_distributions.py
@@ -1328,17 +1328,14 @@ def test_dirichlet(self, n):
             Dirichlet, Simplex(n), {"a": Vector(Rplus, n)}, dirichlet_logpdf
         )
 
-    @pytest.mark.parametrize("n", [3, 4])
-    def test_dirichlet_init_fail(self, n):
-        with Model():
-            with pytest.raises(
-                ValueError, match=r"All concentration parameters \(a\) must be > 0."
-            ):
-                _ = Dirichlet("x", a=np.zeros(n), shape=n)
-            with pytest.raises(
-                ValueError, match=r"All concentration parameters \(a\) must be > 0."
-            ):
-                _ = Dirichlet("x", a=np.array([-1.0] * n), shape=n)
+    def test_dirichlet_shape(self):
+        a = tt.as_tensor_variable(np.r_[1, 2])
+        with pytest.warns(DeprecationWarning):
+            dir_rv = Dirichlet.dist(a)
+        assert dir_rv.shape == (2,)
+
+        with pytest.warns(DeprecationWarning), theano.change_flags(compute_test_value="ignore"):
+            dir_rv = Dirichlet.dist(tt.vector())
 
     def test_dirichlet_2D(self):
         self.pymc3_matches_scipy(
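
The second case in `test_dirichlet_shape` is the degenerate one: a bare `tt.vector()` carries no test value, so `get_test_value` raises AttributeError and the constructor leaves the shape unset. A sketch of that path (the `change_flags` guard keeps Theano's test-value machinery from raising while the graph is built):

import theano
import theano.tensor as tt
import pymc3 as pm

# No test value on a bare symbolic vector: the shape-inference hack
# falls through its `except AttributeError` branch and records no shape.
with theano.change_flags(compute_test_value="ignore"):
    dir_rv = pm.Dirichlet.dist(tt.vector())
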
8 changes: 4 additions & 4 deletions pymc3/tests/test_distributions_random.py
@@ -912,15 +912,15 @@ def test_mixture_random_shape():
                         nr.poisson(9, size=10)])
     with pm.Model() as m:
         comp0 = pm.Poisson.dist(mu=np.ones(2))
-        w0 = pm.Dirichlet('w0', a=np.ones(2))
+        w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
         like0 = pm.Mixture('like0',
                            w=w0,
                            comp_dists=comp0,
                            observed=y)
 
         comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
                                 shape=(20, 2))
-        w1 = pm.Dirichlet('w1', a=np.ones(2))
+        w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
         like1 = pm.Mixture('like1',
                            w=w1,
                            comp_dists=comp1,
@@ -967,15 +967,15 @@ def test_mixture_random_shape_fast():
                         nr.poisson(9, size=10)])
     with pm.Model() as m:
         comp0 = pm.Poisson.dist(mu=np.ones(2))
-        w0 = pm.Dirichlet('w0', a=np.ones(2))
+        w0 = pm.Dirichlet('w0', a=np.ones(2), shape=(2,))
         like0 = pm.Mixture('like0',
                            w=w0,
                            comp_dists=comp0,
                            observed=y)
 
         comp1 = pm.Poisson.dist(mu=np.ones((20, 2)),
                                 shape=(20, 2))
-        w1 = pm.Dirichlet('w1', a=np.ones(2))
+        w1 = pm.Dirichlet('w1', a=np.ones(2), shape=(2,))
         like1 = pm.Mixture('like1',
                            w=w1,
                            comp_dists=comp1,
24 changes: 12 additions & 12 deletions pymc3/tests/test_mixture.py
@@ -79,7 +79,7 @@ def test_dimensions(self):
 
     def test_mixture_list_of_normals(self):
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
+            w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
             mu = Normal('mu', 0., 10., shape=self.norm_w.size)
             tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
             Mixture('x_obs', w,
@@ -98,7 +98,7 @@ def test_mixture_list_of_normals(self):
 
     def test_normal_mixture(self):
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones_like(self.norm_w)))
+            w = Dirichlet('w', floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
             mu = Normal('mu', 0., 10., shape=self.norm_w.size)
             tau = Gamma('tau', 1., 1., shape=self.norm_w.size)
             NormalMixture('x_obs', w, mu, tau=tau, observed=self.norm_x)
@@ -135,7 +135,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
         with Model() as model0:
             mus = Normal('mus', shape=comp_shape)
             taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
-            ws = Dirichlet('ws', np.ones(ncomp))
+            ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
             mixture0 = NormalMixture('m', w=ws, mu=mus, tau=taus, shape=nd,
                                      comp_shape=comp_shape)
             obs0 = NormalMixture('obs', w=ws, mu=mus, tau=taus, shape=nd,
@@ -145,7 +145,7 @@
         with Model() as model1:
             mus = Normal('mus', shape=comp_shape)
             taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
-            ws = Dirichlet('ws', np.ones(ncomp))
+            ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
             comp_dist = [Normal.dist(mu=mus[..., i], tau=taus[..., i],
                                      shape=nd)
                          for i in range(ncomp)]
@@ -163,7 +163,7 @@
             # comp_dists.
             mus = Normal('mus', shape=comp_shape)
             taus = Gamma('taus', alpha=1, beta=1, shape=comp_shape)
-            ws = Dirichlet('ws', np.ones(ncomp))
+            ws = Dirichlet('ws', np.ones(ncomp), shape=(ncomp,))
             if len(nd) > 1:
                 if nd[-1] != ncomp:
                     with pytest.raises(ValueError):
@@ -208,7 +208,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
 
     def test_poisson_mixture(self):
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
+            w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
             mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
             Mixture('x_obs', w, Poisson.dist(mu), observed=self.pois_x)
             step = Metropolis()
@@ -224,7 +224,7 @@
 
     def test_mixture_list_of_poissons(self):
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones_like(self.pois_w)))
+            w = Dirichlet('w', floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape)
             mu = Gamma('mu', 1., 1., shape=self.pois_w.size)
             Mixture('x_obs', w,
                     [Poisson.dist(mu[0]), Poisson.dist(mu[1])],
@@ -247,7 +247,7 @@ def test_mixture_of_mvn(self):
         cov2 = np.diag([2.5, 3.5])
         obs = np.asarray([[.5, .5], mu1, mu2])
         with Model() as model:
-            w = Dirichlet('w', floatX(np.ones(2)), transform=None)
+            w = Dirichlet('w', floatX(np.ones(2)), transform=None, shape=(2,))
             mvncomp1 = MvNormal.dist(mu=mu1, cov=cov1)
             mvncomp2 = MvNormal.dist(mu=mu2, cov=cov2)
             y = Mixture('x_obs', w, [mvncomp1, mvncomp2],
@@ -291,13 +291,13 @@ def test_mixture_of_mixture(self):
                 sigma=1,
                 shape=nbr)
             # weight vector for the mixtures
-            g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
-            l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None)
+            g_w = Dirichlet('g_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
+            l_w = Dirichlet('l_w', a=floatX(np.ones(nbr)*0.0000001), transform=None, shape=(nbr,))
             # mixture components
             g_mix = Mixture.dist(w=g_w, comp_dists=g_comp)
             l_mix = Mixture.dist(w=l_w, comp_dists=l_comp)
             # mixture of mixtures
-            mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None)
+            mix_w = Dirichlet('mix_w', a=floatX(np.ones(2)), transform=None, shape=(2,))
             mix = Mixture('mix', w=mix_w,
                           comp_dists=[g_mix, l_mix],
                           observed=np.exp(self.norm_x))
@@ -378,7 +378,7 @@ def build_toy_dataset(N, K):
     X, y = build_toy_dataset(N, K)
 
     with pm.Model() as model:
-        pi = pm.Dirichlet('pi', np.ones(K))
+        pi = pm.Dirichlet('pi', np.ones(K), shape=(K,))
 
         comp_dist = []
         mu = []
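
The test churn above reduces to a single migration pattern for downstream models (a sketch; `K` and the variable names are illustrative): pass `shape` explicitly wherever a Dirichlet is constructed, which sidesteps the deprecated inference path entirely.

import numpy as np
import pymc3 as pm

K = 3  # number of mixture components (illustrative)

with pm.Model():
    # Before this commit, pm.Dirichlet('w', a=np.ones(K)) inferred the shape
    # from `a`; it now emits a DeprecationWarning when `shape` is omitted.
    w = pm.Dirichlet('w', a=np.ones(K), shape=(K,))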
