diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index 7ef7dde2f74..b15c42f6243 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -24,7 +24,7 @@ from aesara.tensor.random.op import RandomVariable
 
 from pymc.aesaraf import take_along_axis
-from pymc.distributions.continuous import Normal
+from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Discrete, Distribution, SymbolicDistribution
 from pymc.distributions.logprob import logp
@@ -395,7 +395,7 @@ def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs):
     return mix_logp
 
 
-class NormalMixture(Mixture):
+class NormalMixture:
     R"""
     Normal mixture log-likelihood
 
@@ -450,18 +450,20 @@ class NormalMixture(Mixture):
             pm.NormalMixture("y", w=weights, mu=μ, sigma=σ, observed=data)
     """
 
-    def __init__(self, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), *args, **kwargs):
+    def __new__(cls, name, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), **kwargs):
         if sd is not None:
             sigma = sd
         _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
 
-        self.mu = mu = at.as_tensor_variable(mu)
-        self.sigma = self.sd = sigma = at.as_tensor_variable(sigma)
+        return Mixture(name, w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
 
-        super().__init__(w, Normal.dist(mu, sigma=sigma, shape=comp_shape), *args, **kwargs)
+    @classmethod
+    def dist(cls, w, mu, sigma=None, tau=None, sd=None, comp_shape=(), **kwargs):
+        if sd is not None:
+            sigma = sd
+        _, sigma = get_tau_sigma(tau=tau, sigma=sigma)
 
-    def _distr_parameters_for_repr(self):
-        return ["w", "mu", "sigma"]
+        return Mixture.dist(w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
 
 
 class MixtureSameFamily(Distribution):
diff --git a/pymc/tests/test_distributions_random.py b/pymc/tests/test_distributions_random.py
index 866053d625f..6ab96736347 100644
--- a/pymc/tests/test_distributions_random.py
+++ b/pymc/tests/test_distributions_random.py
@@ -73,6 +73,7 @@ def pymc_random(
     fails=10,
     extra_args=None,
     model_args=None,
+    change_rv_size_fn=change_rv_size,
 ):
     if valuedomain is None:
         valuedomain = Domain([0], edges=(None, None))
@@ -81,7 +82,7 @@ def pymc_random(
         model_args = {}
 
     model, param_vars = build_model(dist, valuedomain, paramdomains, extra_args)
-    model_dist = change_rv_size(model.named_vars["value"], size, expand=True)
+    model_dist = change_rv_size_fn(model.named_vars["value"], size, expand=True)
     pymc_rand = aesara.function([], model_dist)
 
     domains = paramdomains.copy()
diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py
index a861350c495..72b3cd822f1 100644
--- a/pymc/tests/test_mixture.py
+++ b/pymc/tests/test_mixture.py
@@ -577,28 +577,33 @@ def mixmixlogp(value, point):
         assert_allclose(priorlogp + mixmixlogpg.sum(), model.logp(test_point), rtol=rtol)
 
 
-@pytest.mark.xfail(reason="NormalMixture not refactored yet")
 class TestNormalMixture(SeededTest):
-    @classmethod
-    def setup_class(cls):
-        TestMixture.setup_class()
+    def test_normal_mixture_sampling(self):
+        norm_w = np.array([0.75, 0.25])
+        norm_mu = np.array([0.0, 5.0])
+        norm_sd = np.ones_like(norm_mu)
+        norm_x = generate_normal_mixture_data(norm_w, norm_mu, norm_sd, size=1000)
 
-    def test_normal_mixture(self):
         with Model() as model:
-            w = Dirichlet("w", floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size)
-            mu = Normal("mu", 0.0, 10.0, shape=self.norm_w.size)
-            tau = Gamma("tau", 1.0, 1.0, shape=self.norm_w.size)
-            NormalMixture("x_obs", w, mu, tau=tau, observed=self.norm_x)
+            w = Dirichlet("w", floatX(np.ones_like(norm_w)), shape=norm_w.size)
+            mu = Normal("mu", 0.0, 10.0, shape=norm_w.size)
+            tau = Gamma("tau", 1.0, 1.0, shape=norm_w.size)
+            NormalMixture("x_obs", w, mu, tau=tau, observed=norm_x)
             step = Metropolis()
-            trace = sample(5000, step, random_seed=self.random_seed, progressbar=False, chains=1)
+            trace = sample(
+                5000,
+                step,
+                random_seed=self.random_seed,
+                progressbar=False,
+                chains=1,
+                return_inferencedata=False,
+            )
 
-        assert_allclose(np.sort(trace["w"].mean(axis=0)), np.sort(self.norm_w), rtol=0.1, atol=0.1)
-        assert_allclose(
-            np.sort(trace["mu"].mean(axis=0)), np.sort(self.norm_mu), rtol=0.1, atol=0.1
-        )
+        assert_allclose(np.sort(trace["w"].mean(axis=0)), np.sort(norm_w), rtol=0.1, atol=0.1)
+        assert_allclose(np.sort(trace["mu"].mean(axis=0)), np.sort(norm_mu), rtol=0.1, atol=0.1)
 
     @pytest.mark.parametrize(
-        "nd,ncomp", [(tuple(), 5), (1, 5), (3, 5), ((3, 3), 5), (3, 3), ((3, 3), 3)], ids=str
+        "nd, ncomp", [(tuple(), 5), (1, 5), (3, 5), ((3, 3), 5), (3, 3), ((3, 3), 3)], ids=str
     )
     def test_normal_mixture_nd(self, nd, ncomp):
         nd = to_tuple(nd)
@@ -616,7 +621,7 @@ def test_normal_mixture_nd(self, nd, ncomp):
             ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
             mixture0 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape)
             obs0 = NormalMixture(
-                "obs", w=ws, mu=mus, tau=taus, shape=nd, comp_shape=comp_shape, observed=observed
+                "obs", w=ws, mu=mus, tau=taus, comp_shape=comp_shape, observed=observed
             )
 
         with Model() as model1:
@@ -627,53 +632,27 @@ def test_normal_mixture_nd(self, nd, ncomp):
                 Normal.dist(mu=mus[..., i], tau=taus[..., i], shape=nd) for i in range(ncomp)
             ]
             mixture1 = Mixture("m", w=ws, comp_dists=comp_dist, shape=nd)
-            obs1 = Mixture("obs", w=ws, comp_dists=comp_dist, shape=nd, observed=observed)
+            obs1 = Mixture("obs", w=ws, comp_dists=comp_dist, observed=observed)
 
         with Model() as model2:
-            # Expected to fail if comp_shape is not provided,
-            # nd is multidim and it does not broadcast with ncomp. If by chance
-            # it does broadcast, an error is raised if the mixture is given
-            # observed data.
-            # Furthermore, the Mixture will also raise errors when the observed
-            # data is multidimensional but it does not broadcast well with
-            # comp_dists.
+            # Test that results are correct without comp_shape being passed to the Mixture.
+            # This used to fail in V3
             mus = Normal("mus", shape=comp_shape)
             taus = Gamma("taus", alpha=1, beta=1, shape=comp_shape)
             ws = Dirichlet("ws", np.ones(ncomp), shape=(ncomp,))
-            if len(nd) > 1:
-                if nd[-1] != ncomp:
-                    with pytest.raises(ValueError):
-                        NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
-                    mixture2 = None
-                else:
-                    mixture2 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
-            else:
-                mixture2 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
-            observed_fails = False
-            if len(nd) >= 1 and nd != (1,):
-                try:
-                    np.broadcast(np.empty(comp_shape), observed)
-                except Exception:
-                    observed_fails = True
-            if observed_fails:
-                with pytest.raises(ValueError):
-                    NormalMixture("obs", w=ws, mu=mus, tau=taus, shape=nd, observed=observed)
-                obs2 = None
-            else:
-                obs2 = NormalMixture("obs", w=ws, mu=mus, tau=taus, shape=nd, observed=observed)
+            mixture2 = NormalMixture("m", w=ws, mu=mus, tau=taus, shape=nd)
+            obs2 = NormalMixture("obs", w=ws, mu=mus, tau=taus, observed=observed)
 
         testpoint = model0.compute_initial_point()
         testpoint["mus"] = test_mus
-        testpoint["taus"] = test_taus
-        assert_allclose(model0.logp(testpoint), model1.logp(testpoint))
-        assert_allclose(mixture0.logp(testpoint), mixture1.logp(testpoint))
-        assert_allclose(obs0.logp(testpoint), obs1.logp(testpoint))
-        if mixture2 is not None and obs2 is not None:
-            assert_allclose(model0.logp(testpoint), model2.logp(testpoint))
-        if mixture2 is not None:
-            assert_allclose(mixture0.logp(testpoint), mixture2.logp(testpoint))
-        if obs2 is not None:
-            assert_allclose(obs0.logp(testpoint), obs2.logp(testpoint))
+        testpoint["taus_log__"] = np.log(test_taus)
+        for logp0, logp1, logp2 in zip(
+            model0.compile_logp(vars=[mixture0, obs0], sum=False)(testpoint),
+            model1.compile_logp(vars=[mixture1, obs1], sum=False)(testpoint),
+            model2.compile_logp(vars=[mixture2, obs2], sum=False)(testpoint),
+        ):
+            assert_allclose(logp0, logp1)
+            assert_allclose(logp0, logp2)
 
     def test_random(self):
         def ref_rand(size, w, mu, sigma):
@@ -690,6 +669,7 @@ def ref_rand(size, w, mu, sigma):
             extra_args={"comp_shape": 2},
             size=1000,
             ref_rand=ref_rand,
+            change_rv_size_fn=Mixture.change_size,
         )
         pymc_random(
             NormalMixture,
@@ -701,6 +681,7 @@ def ref_rand(size, w, mu, sigma):
            extra_args={"comp_shape": 3},
            size=1000,
            ref_rand=ref_rand,
+            change_rv_size_fn=Mixture.change_size,
        )
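For reviewers, here is a minimal usage sketch of the refactored API. It is not part of the patch; the toy data and variable names are illustrative, and it assumes a PyMC v4 environment with this change applied. It mirrors the docstring example and the new `test_normal_mixture_sampling` test.

```python
import numpy as np
import pymc as pm

# Illustrative two-component data, analogous to the test fixture above.
rng = np.random.default_rng(0)
data = np.concatenate([rng.normal(0.0, 1.0, size=750), rng.normal(5.0, 1.0, size=250)])

with pm.Model():
    w = pm.Dirichlet("w", a=np.ones(2))
    mu = pm.Normal("mu", 0.0, 10.0, shape=2)
    tau = pm.Gamma("tau", 1.0, 1.0, shape=2)
    # NormalMixture.__new__ now just dispatches to
    # Mixture(name, w, Normal.dist(mu, sigma=sigma, size=comp_shape), **kwargs)
    pm.NormalMixture("x_obs", w=w, mu=mu, tau=tau, observed=data)

# The new .dist() classmethod builds the same graph via Mixture.dist,
# returning an unnamed variable outside of a model context.
x = pm.NormalMixture.dist(w=np.array([0.75, 0.25]), mu=np.array([0.0, 5.0]), sigma=1.0)
```

Note that `comp_shape` is now forwarded as the `size` of the single `Normal.dist` component, which is why the `test_normal_mixture_nd` cases above no longer need to special-case broadcasting failures when `comp_shape` is omitted.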