From ce37fd45b2916cf14fad0e2678fea435718b0051 Mon Sep 17 00:00:00 2001 From: Ricardo Date: Tue, 1 Feb 2022 15:54:25 +0100 Subject: [PATCH] Refactor Mixture distribution for V4 --- pymc/distributions/mixture.py | 604 +++++++---------------- pymc/tests/test_distributions_moments.py | 2 - pymc/tests/test_mixture.py | 134 ++++- 3 files changed, 291 insertions(+), 449 deletions(-) diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py index 6105bb726ca..327251b5c35 100644 --- a/pymc/distributions/mixture.py +++ b/pymc/distributions/mixture.py @@ -11,19 +11,32 @@ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. # See the License for the specific language governing permissions and # limitations under the License. - -from collections.abc import Iterable +import warnings import aesara import aesara.tensor as at import numpy as np -from pymc.aesaraf import _conversion_map, take_along_axis -from pymc.distributions.continuous import Normal, get_tau_sigma +from aeppl.abstract import MeasurableVariable, _get_measurable_outputs +from aeppl.logprob import _logprob +from aesara.compile.builders import OpFromGraph +from aesara.tensor import TensorVariable +from aesara.tensor.random.op import RandomVariable + +from pymc.aesaraf import take_along_axis +from pymc.distributions.continuous import Normal from pymc.distributions.dist_math import check_parameters -from pymc.distributions.distribution import Discrete, Distribution -from pymc.distributions.shape_utils import to_tuple +from pymc.distributions.distribution import ( + Discrete, + Distribution, + SymbolicDistribution, + _get_moment, + get_moment, +) +from pymc.distributions.logprob import logp +from pymc.distributions.shape_utils import ShapeWarning, to_tuple from pymc.math import logsumexp +from pymc.util import check_dist_not_registered __all__ = ["Mixture", "NormalMixture", "MixtureSameFamily"] @@ -38,7 +51,14 @@ def all_discrete(comp_dists): return all(isinstance(comp_dist, Discrete) for comp_dist in comp_dists) -class Mixture(Distribution): +class MarginalMixtureRV(OpFromGraph): + """A placeholder used to specify a log-likelihood for a mixture sub-graph.""" + + +MeasurableVariable.register(MarginalMixtureRV) + + +class Mixture(SymbolicDistribution): R""" Mixture log-likelihood @@ -112,454 +132,178 @@ class Mixture(Distribution): like = pm.Mixture('like', w=w, comp_dists = components, observed=data, shape=3) """ - def __init__(self, w, comp_dists, *args, **kwargs): - # comp_dists type checking - if not ( - isinstance(comp_dists, Distribution) - or ( - isinstance(comp_dists, Iterable) - and all(isinstance(c, Distribution) for c in comp_dists) - ) - ): - raise TypeError( - "Supplied Mixture comp_dists must be a " - "Distribution or an iterable of " - "Distributions. Got {} instead.".format( - type(comp_dists) - if not isinstance(comp_dists, Iterable) - else [type(c) for c in comp_dists] + @classmethod + def dist(cls, w, comp_dists, **kwargs): + # TODO: Reintroduce support for single variable comp_dists + + # Check that components are not associated with a registered variable in the model + components_ndim = set() + components_ndim_supp = set() + for dist in comp_dists: + if not isinstance(dist, TensorVariable) or not isinstance( + dist.owner.op, RandomVariable + ): + raise ValueError( + f"Component dist must be a distribution created via the `.dist()` API, got {type(dist)}" ) - ) - shape = kwargs.pop("shape", ()) + check_dist_not_registered(dist) + components_ndim.add(dist.ndim) + components_ndim_supp.add(dist.owner.op.ndim_supp) - self.w = w = at.as_tensor_variable(w) - self.comp_dists = comp_dists + if len(components_ndim) > 1: + raise ValueError( + f"Mixture components must all have the same dimensionality, got {components_ndim}" + ) - defaults = kwargs.pop("defaults", []) + if len(components_ndim_supp) > 1: + raise ValueError( + f"Mixture components must all have the same support dimensionality, got {components_ndim_supp}" + ) - if all_discrete(comp_dists): - default_dtype = _conversion_map[aesara.config.floatX] + w = at.as_tensor_variable(w) + + # ShapeWarning does not make sense here + with warnings.catch_warnings(): + warnings.simplefilter("ignore", ShapeWarning) + return super().dist([w, *comp_dists], **kwargs) + + @classmethod + def rv_op(cls, weights, *components, size=None, rngs=None): + # Update rngs if provided + if rngs is not None: + *components_rngs, choices_rng = rngs + new_components = [] + assert len(components) == len(components_rngs) + for component, component_rng in zip(components, components_rngs): + component_node = component.owner + old_rng, *inputs = component_node.inputs + new_components.append( + component_node.op.make_node(component_rng, *inputs).default_output() + ) + components = new_components else: - default_dtype = aesara.config.floatX - - try: - self.mean = (w * self._comp_means()).sum(axis=-1) - - if "mean" not in defaults: - defaults.append("mean") - except AttributeError: - pass - dtype = kwargs.pop("dtype", default_dtype) - - try: - if isinstance(comp_dists, Distribution): - comp_mode_logps = comp_dists.logp(comp_dists.mode) - else: - comp_mode_logps = at.stack([cd.logp(cd.mode) for cd in comp_dists]) + # Create new rng for the choices internal RV + choices_rng = aesara.shared(np.random.default_rng()) - mode_idx = at.argmax(at.log(w) + comp_mode_logps, axis=-1) - self.mode = self._comp_modes()[mode_idx] + # Create a OpFromGraph that encapsulates the random generating process + weights_type = weights.type() + components_type = [component.type() for component in components] + rng_type = choices_rng.type() - if "mode" not in defaults: - defaults.append("mode") - except (AttributeError, ValueError, IndexError): - pass - - super().__init__(shape, dtype, defaults=defaults, *args, **kwargs) - - @property - def comp_dists(self): - return self._comp_dists - - @comp_dists.setter - def comp_dists(self, comp_dists): - self._comp_dists = comp_dists - if isinstance(comp_dists, Distribution): - self._comp_dist_shapes = to_tuple(comp_dists.shape) - self._broadcast_shape = self._comp_dist_shapes - self.comp_is_distribution = True + stacked_components_ = at.stack(components_type, axis=-1) + if weights.ndim < stacked_components_.ndim: + weights_ = at.shape_padaxis(weights_type, axis=weights_type.ndim - 1) else: - # Now we check the comp_dists distribution shape, see what - # the broadcast shape would be. This shape will be the dist_shape - # used by generate samples (the shape of a single random sample) - # from the mixture - self._comp_dist_shapes = [to_tuple(d.shape) for d in comp_dists] - # All component distributions must broadcast with each other - try: - self._broadcast_shape = np.broadcast( - *(np.empty(shape) for shape in self._comp_dist_shapes) - ).shape - except Exception: - raise TypeError( - "Supplied comp_dists shapes do not broadcast " - "with each other. comp_dists shapes are: " - "{}".format(self._comp_dist_shapes) - ) + weights_ = weights_type + broadcasted_weights_ = at.broadcast_to(weights_, stacked_components_.shape) - # We wrap the _comp_dist.random by adding the kwarg raw_size_, - # which will be the size attribute passed to _comp_samples. - # _comp_samples then calls generate_samples, which may change the - # size value to make it compatible with scipy.stats.*.rvs - self._generators = [] - for comp_dist in comp_dists: - generator = Mixture._comp_dist_random_wrapper(comp_dist.random) - self._generators.append(generator) - self.comp_is_distribution = False - - @staticmethod - def _comp_dist_random_wrapper(random): - """Wrap the comp_dists.random method to take the kwarg raw_size_ and - use it's value to replace the size parameter. This is needed because - generate_samples makes the size value compatible with the - scipy.stats.*.rvs, where size has a different meaning than in the - distributions' random methods. - """ + choices_ = at.random.categorical(broadcasted_weights_, rng=rng_type) + mix_out_ = at.take_along_axis(stacked_components_, choices_[..., None], axis=-1) + mix_out_ = at.reshape(mix_out_, components_type[0].shape) + + next_rng_ = choices_.owner.outputs[0] + mix_op = MarginalMixtureRV( + inputs=[rng_type, weights_type, *components_type], outputs=[next_rng_, mix_out_] + ) - def wrapped_random(*args, **kwargs): - raw_size_ = kwargs.pop("raw_size_", None) - # Distribution.random's signature is always (point=None, size=None) - # so size could be the second arg or be given as a kwarg - if len(args) > 1: - args[1] = raw_size_ + # Create the actual MarginalMixture variable + next_rng, mix_out = mix_op(choices_rng, weights, *components) + + # We need to set_default_updates ourselves, because the choices RV is hidden + # inside OpFromGraph and PyMC will never find it otherwise + choices_rng.default_update = next_rng + + # Reference nodes to facilitate identification in other classmethods + mix_out.tag.weights = weights + mix_out.tag.components = components + mix_out.tag.choices_rng = choices_rng + + # Component RVs terms are accounted by the Mixture logprob, so they can be + # safely ignore by Aeppl (this tag prevents UserWarning) + for component in components: + component.tag.ignore_logprob = True + + if size is not None: + mix_out = cls.change_size(mix_out, size) + + return mix_out + + @classmethod + def ndim_supp(cls, rng, weights, *components): + # We already checked that all components have the same dimensionality + return components[0].ndim + + @classmethod + def change_size(cls, rv, new_size): + component_nodes = [node.owner for node in rv.tag.components] + new_components = [] + rngs = [] + for component_node in component_nodes: + rng, old_size, dtype, *dist_params = component_node.inputs + old_size_len = -at.get_vector_length(old_size) + # Avoid issue with [:-0] slice + if old_size_len: + extended_size = at.concatenate([to_tuple(new_size)[:old_size_len], old_size]) else: - kwargs["size"] = raw_size_ - return random(*args, **kwargs) + extended_size = to_tuple(new_size) + new_components.append( + component_node.op.make_node( + rng, extended_size, dtype, *dist_params + ).default_output() + ) + rngs.append(rng) - return wrapped_random + rngs.append(rv.tag.choices_rng) - def _comp_logp(self, value): - comp_dists = self.comp_dists + weights = rv.tag.weights + return cls.rv_op(weights, *new_components, rngs=rngs) - if self.comp_is_distribution: - # Value can be many things. It can be the self tensor, the mode - # test point or it can be observed data. The latter case requires - # careful handling of shape, as the observed's shape could look - # like (repetitions,) + dist_shape, which does not include the last - # mixture axis. For this reason, we try to eval the value.shape, - # compare it with self.shape and shape_padright if we infer that - # the value holds observed data - try: - val_shape = tuple(value.shape.eval()) - except AttributeError: - val_shape = value.shape - except aesara.graph.fg.MissingInputError: - val_shape = None - try: - self_shape = tuple(self.shape) - except AttributeError: - # Happens in __init__ when computing self.logp(comp_modes) - self_shape = None - comp_shape = tuple(comp_dists.shape) - ndim = value.ndim - if val_shape is not None and not ( - (self_shape is not None and val_shape == self_shape) or val_shape == comp_shape - ): - # value is neither the test point nor the self tensor, it - # is likely to hold observed values, so we must compute the - # ndim discarding the dimensions that don't match - # self_shape - if self_shape and val_shape[-len(self_shape) :] == self_shape: - # value has observed values for the Mixture - ndim = len(self_shape) - elif comp_shape and val_shape[-len(comp_shape) :] == comp_shape: - # value has observed for the Mixture components - ndim = len(comp_shape) - else: - # We cannot infer what was passed, we handle this - # as was done in earlier versions of Mixture. We pad - # always if ndim is lower or equal to 1 (default - # legacy implementation) - if ndim <= 1: - ndim = len(comp_dists.shape) - 1 - else: - # We reach this point if value does not hold observed data, so - # we can use its ndim safely to determine shape padding, or it - # holds something that we cannot infer, so we revert to using - # the value's ndim for shape padding. - # We will always pad a single dimension if ndim is lower or - # equal to 1 (default legacy implementation) - if ndim <= 1: - ndim = len(comp_dists.shape) - 1 - if ndim < len(comp_dists.shape): - value_ = at.shape_padright(value, len(comp_dists.shape) - ndim) - else: - value_ = value - return comp_dists.logp(value_) - else: - return at.squeeze( - at.stack([comp_dist.logp(value) for comp_dist in comp_dists], axis=-1) - ) + @classmethod + def graph_rvs(cls, rv): + # We return rv, which is a pseudo RandomVariable, that contains a choices RV + # in its inner graph. We want super().dist() to generate components + 1 rngs for + # us, and it will do so based on how many elements we return here + return (*rv.tag.components, rv) - def _comp_means(self): - try: - return at.as_tensor_variable(self.comp_dists.mean) - except AttributeError: - return at.squeeze(at.stack([comp_dist.mean for comp_dist in self.comp_dists], axis=-1)) - - def _comp_modes(self): - try: - return at.as_tensor_variable(self.comp_dists.mode) - except AttributeError: - return at.squeeze(at.stack([comp_dist.mode for comp_dist in self.comp_dists], axis=-1)) - - def _comp_samples(self, point=None, size=None, comp_dist_shapes=None, broadcast_shape=None): - # if self.comp_is_distribution: - # samples = self._comp_dists.random(point=point, size=size) - # else: - # if comp_dist_shapes is None: - # comp_dist_shapes = self._comp_dist_shapes - # if broadcast_shape is None: - # broadcast_shape = self._sample_shape - # samples = [] - # for dist_shape, generator in zip(comp_dist_shapes, self._generators): - # sample = generate_samples( - # generator=generator, - # dist_shape=dist_shape, - # broadcast_shape=broadcast_shape, - # point=point, - # size=size, - # not_broadcast_kwargs={"raw_size_": size}, - # ) - # samples.append(sample) - # samples = np.array(broadcast_distribution_samples(samples, size=size)) - # # In the logp we assume the last axis holds the mixture components - # # so we move the axis to the last dimension - # samples = np.moveaxis(samples, 0, -1) - # return samples.astype(self.dtype) - pass - - def infer_comp_dist_shapes(self, point=None): - """Try to infer the shapes of the component distributions, - `comp_dists`, and how they should broadcast together. - The behavior is slightly different if `comp_dists` is a `Distribution` - as compared to when it is a list of `Distribution`s. When it is a list - the following procedure is repeated for each element in the list: - 1. Look up the `comp_dists.shape` - 2. If it is not empty, use it as `comp_dist_shape` - 3. If it is an empty tuple, a single random sample is drawn by calling - `comp_dists.random(point=point, size=None)`, and the returned - test_sample's shape is used as the inferred `comp_dists.shape` - Parameters - ---------- - point: None or dict (optional) - Dictionary that maps rv names to values, to supply to - `self.comp_dists.random` +@_get_measurable_outputs.register(MarginalMixtureRV) +def _get_measurable_outputs_MarginalMixtureRV(op, node): + # This tells Aeppl that the second output is the measurable one + return [node.outputs[1]] - Returns - ------- - comp_dist_shapes: shape tuple or list of shape tuples. - If `comp_dists` is a `Distribution`, it is a shape tuple of the - inferred distribution shape. - If `comp_dists` is a list of `Distribution`s, it is a list of - shape tuples inferred for each element in `comp_dists` - broadcast_shape: shape tuple - The shape that results from broadcasting all component's shapes - together. - """ - # if self.comp_is_distribution: - # if len(self._comp_dist_shapes) > 0: - # comp_dist_shapes = self._comp_dist_shapes - # else: - # # Happens when the distribution is a scalar or when it was not - # # given a shape. In these cases we try to draw a single value - # # to check its shape, we use the provided point dictionary - # # hoping that it can circumvent the Flat and HalfFlat - # # undrawable distributions. - # with _DrawValuesContextBlocker(): - # test_sample = self._comp_dists.random(point=point, size=None) - # comp_dist_shapes = test_sample.shape - # broadcast_shape = comp_dist_shapes - # else: - # # Now we check the comp_dists distribution shape, see what - # # the broadcast shape would be. This shape will be the dist_shape - # # used by generate samples (the shape of a single random sample) - # # from the mixture - # comp_dist_shapes = [] - # for dist_shape, comp_dist in zip(self._comp_dist_shapes, self._comp_dists): - # if dist_shape == tuple(): - # # Happens when the distribution is a scalar or when it was - # # not given a shape. In these cases we try to draw a single - # # value to check its shape, we use the provided point - # # dictionary hoping that it can circumvent the Flat and - # # HalfFlat undrawable distributions. - # with _DrawValuesContextBlocker(): - # test_sample = comp_dist.random(point=point, size=None) - # dist_shape = test_sample.shape - # comp_dist_shapes.append(dist_shape) - # # All component distributions must broadcast with each other - # try: - # broadcast_shape = np.broadcast( - # *[np.empty(shape) for shape in comp_dist_shapes] - # ).shape - # except Exception: - # raise TypeError( - # "Inferred comp_dist shapes do not broadcast " - # "with each other. comp_dists inferred shapes " - # "are: {}".format(comp_dist_shapes) - # ) - # return comp_dist_shapes, broadcast_shape - def logp(self, value): - """ - Calculate log-probability of defined Mixture distribution at specified value. +@_logprob.register(MarginalMixtureRV) +def marginal_mixture_logprob(op, values, rng, weights, *components, **kwargs): + (value,) = values - Parameters - ---------- - value: numeric - Value(s) for which log-probability is calculated. If the log probabilities for multiple - values are desired the values must be provided in a numpy array or Aesara tensor + components_logp = at.stack( + [logp(component, value) for component in components], + axis=-1, + ) - Returns - ------- - TensorVariable - """ - w = self.w + # TODO: Is one padding always enough? + if weights.ndim < components_logp.ndim: + weights = at.shape_padaxis(weights, axis=weights.ndim - 1) - return check_parameters( - logsumexp(at.log(w) + self._comp_logp(value), axis=-1, keepdims=False), - w >= 0, - w <= 1, - at.allclose(w.sum(axis=-1), 1), - broadcast_conditions=False, - ) + mix_logp = at.logsumexp(at.log(weights) + components_logp, axis=-1) + # TODO: Do something better than this! + mix_logp = at.squeeze(mix_logp) - def random(self, point=None, size=None): - """ - Draw random values from defined Mixture distribution. + mix_logp = check_parameters( + mix_logp, + 0 <= weights, + weights <= 1, + at.isclose(at.sum(weights, axis=-1), 1), + msg="0 <= weights <= 1, sum(weights) == 1", + ) - Parameters - ---------- - point: dict, optional - Dict of variable values on which random values are to be - conditioned (uses default point if not specified). - size: int, optional - Desired size of random sample (returns one sample if not - specified). + return mix_logp - Returns - ------- - array - """ - # # Convert size to tuple - # size = to_tuple(size) - # # Draw mixture weights and infer the comp_dists shapes - # with _DrawValuesContext() as draw_context: - # # We first need to check w and comp_tmp shapes and re compute size - # w = draw_values([self.w], point=point, size=size)[0] - # comp_dist_shapes, broadcast_shape = self.infer_comp_dist_shapes(point=point) - # - # # When size is not None, it's hard to tell the w parameter shape - # if size is not None and w.shape[: len(size)] == size: - # w_shape = w.shape[len(size) :] - # else: - # w_shape = w.shape - # - # # Try to determine parameter shape and dist_shape - # if self.comp_is_distribution: - # param_shape = np.broadcast(np.empty(w_shape), np.empty(broadcast_shape)).shape - # else: - # param_shape = np.broadcast(np.empty(w_shape), np.empty(broadcast_shape + (1,))).shape - # if np.asarray(self.shape).size != 0: - # dist_shape = np.broadcast(np.empty(self.shape), np.empty(param_shape[:-1])).shape - # else: - # dist_shape = param_shape[:-1] - # - # # Try to determine the size that must be used to get the mixture - # # components (i.e. get random choices using w). - # # 1. There must be size independent choices based on w. - # # 2. There must also be independent draws for each non singleton axis - # # of w. - # # 3. There must also be independent draws for each dimension added by - # # self.shape with respect to the w.ndim. These usually correspond to - # # observed variables with batch shapes - # wsh = (1,) * (len(dist_shape) - len(w_shape) + 1) + w_shape[:-1] - # psh = (1,) * (len(dist_shape) - len(param_shape) + 1) + param_shape[:-1] - # w_sample_size = [] - # # Loop through the dist_shape to get the conditions 2 and 3 first - # for i in range(len(dist_shape)): - # if dist_shape[i] != psh[i] and wsh[i] == 1: - # # self.shape[i] is a non singleton dimension (usually caused by - # # observed data) - # sh = dist_shape[i] - # else: - # sh = wsh[i] - # w_sample_size.append(sh) - # if size is not None and w_sample_size[: len(size)] != size: - # w_sample_size = size + tuple(w_sample_size) - # # Broadcast w to the w_sample_size (add a singleton last axis for the - # # mixture components) - # w = broadcast_distribution_samples([w, np.empty(w_sample_size + (1,))], size=size)[0] - # - # # Semiflatten the mixture weights. The last axis is the number of - # # mixture mixture components, and the rest is all about size, - # # dist_shape and broadcasting - # w_ = np.reshape(w, (-1, w.shape[-1])) - # w_samples = random_choice(p=w_, size=None) # w's shape already includes size - # # Now we broadcast the chosen components to the dist_shape - # w_samples = np.reshape(w_samples, w.shape[:-1]) - # if size is not None and dist_shape[: len(size)] != size: - # w_samples = np.broadcast_to(w_samples, size + dist_shape) - # else: - # w_samples = np.broadcast_to(w_samples, dist_shape) - # - # # When size is not None, maybe dist_shape partially overlaps with size - # if size is not None: - # if size == dist_shape: - # size = None - # elif size[-len(dist_shape) :] == dist_shape: - # size = size[: len(size) - len(dist_shape)] - # - # # We get an integer _size instead of a tuple size for drawing the - # # mixture, then we just reshape the output - # if size is None: - # _size = None - # else: - # _size = int(np.prod(size)) - # - # # Compute the total size of the mixture's random call with size - # if _size is not None: - # output_size = int(_size * np.prod(dist_shape) * param_shape[-1]) - # else: - # output_size = int(np.prod(dist_shape) * param_shape[-1]) - # # Get the size we need for the mixture's random call - # if self.comp_is_distribution: - # mixture_size = int(output_size // np.prod(broadcast_shape)) - # else: - # mixture_size = int(output_size // (np.prod(broadcast_shape) * param_shape[-1])) - # if mixture_size == 1 and _size is None: - # mixture_size = None - # - # # Sample from the mixture - # with draw_context: - # mixed_samples = self._comp_samples( - # point=point, - # size=mixture_size, - # broadcast_shape=broadcast_shape, - # comp_dist_shapes=comp_dist_shapes, - # ) - # # Test that the mixture has the same number of "samples" as w - # if w_samples.size != (mixed_samples.size // w.shape[-1]): - # raise ValueError( - # "Inconsistent number of samples from the " - # "mixture and mixture weights. Drew {} mixture " - # "weights elements, and {} samples from the " - # "mixture components.".format(w_samples.size, mixed_samples.size // w.shape[-1]) - # ) - # # Semiflatten the mixture to be able to zip it with w_samples - # w_samples = w_samples.flatten() - # mixed_samples = np.reshape(mixed_samples, (-1, w.shape[-1])) - # # Select the samples from the mixture - # samples = np.array([mixed[choice] for choice, mixed in zip(w_samples, mixed_samples)]) - # # Reshape the samples to the correct output shape - # if size is None: - # samples = np.reshape(samples, dist_shape) - # else: - # samples = np.reshape(samples, size + dist_shape) - # return samples - def _distr_parameters_for_repr(self): - return [] +@_get_moment.register(MarginalMixtureRV) +def get_moment_marginal_mixture(op, rv, rng, weights, *components): + moment_components = at.stack([get_moment(component) for component in components], axis=-1) + return at.sum(weights * moment_components, axis=-1) class NormalMixture(Mixture): diff --git a/pymc/tests/test_distributions_moments.py b/pymc/tests/test_distributions_moments.py index 0f163d73c21..dc38c1f68bf 100644 --- a/pymc/tests/test_distributions_moments.py +++ b/pymc/tests/test_distributions_moments.py @@ -99,9 +99,7 @@ def test_all_distributions_have_moments(): # Distributions that have not been refactored for V4 yet not_implemented = { - dist_module.mixture.Mixture, dist_module.mixture.MixtureSameFamily, - dist_module.mixture.NormalMixture, dist_module.timeseries.AR, dist_module.timeseries.AR1, dist_module.timeseries.GARCH11, diff --git a/pymc/tests/test_mixture.py b/pymc/tests/test_mixture.py index bd231959d44..28a4bd89fc7 100644 --- a/pymc/tests/test_mixture.py +++ b/pymc/tests/test_mixture.py @@ -41,7 +41,6 @@ from pymc.distributions.shape_utils import to_tuple from pymc.tests.helpers import SeededTest -pytestmark = pytest.mark.xfail(reason="Mixture not refactored.") # Generate data def generate_normal_mixture_data(w, mu, sd, size=1000): @@ -87,11 +86,11 @@ def test_dimensions(self): a2 = Normal.dist(mu=10, sigma=1) mix = Mixture.dist(w=np.r_[0.5, 0.5], comp_dists=[a1, a2]) - assert mix.mode.ndim == 0 - assert mix.logp(0.0).ndim == 0 + assert mix.eval().ndim == 0 + assert pm.logp(mix, 0.0).ndim == 0 value = np.r_[0.0, 1.0, 2.0] - assert mix.logp(value).ndim == 1 + assert pm.logp(mix, value).ndim == 1 def test_mixture_list_of_normals(self): with Model() as model: @@ -105,13 +104,21 @@ def test_mixture_list_of_normals(self): observed=self.norm_x, ) step = Metropolis() - trace = sample(5000, step, random_seed=self.random_seed, progressbar=False, chains=1) + trace = sample( + 5000, + step, + random_seed=self.random_seed, + progressbar=False, + chains=1, + return_inferencedata=False, + ) assert_allclose(np.sort(trace["w"].mean(axis=0)), np.sort(self.norm_w), rtol=0.1, atol=0.1) assert_allclose( np.sort(trace["mu"].mean(axis=0)), np.sort(self.norm_mu), rtol=0.1, atol=0.1 ) + @pytest.mark.xfail(reason="NormalMixture not refactored yet") def test_normal_mixture(self): with Model() as model: w = Dirichlet("w", floatX(np.ones_like(self.norm_w)), shape=self.norm_w.size) @@ -119,13 +126,21 @@ def test_normal_mixture(self): tau = Gamma("tau", 1.0, 1.0, shape=self.norm_w.size) NormalMixture("x_obs", w, mu, tau=tau, observed=self.norm_x) step = Metropolis() - trace = sample(5000, step, random_seed=self.random_seed, progressbar=False, chains=1) + trace = sample( + 5000, + step, + random_seed=self.random_seed, + progressbar=False, + chains=1, + return_inferencedata=False, + ) assert_allclose(np.sort(trace["w"].mean(axis=0)), np.sort(self.norm_w), rtol=0.1, atol=0.1) assert_allclose( np.sort(trace["mu"].mean(axis=0)), np.sort(self.norm_mu), rtol=0.1, atol=0.1 ) + @pytest.mark.xfail(reason="NormalMixture not refactored yet") @pytest.mark.parametrize( "nd,ncomp", [(tuple(), 5), (1, 5), (3, 5), ((3, 3), 5), (3, 3), ((3, 3), 3)], ids=str ) @@ -204,13 +219,21 @@ def test_normal_mixture_nd(self, nd, ncomp): if obs2 is not None: assert_allclose(obs0.logp(testpoint), obs2.logp(testpoint)) + @pytest.mark.xfail(reason="Mixture from single component not refactored yet") def test_poisson_mixture(self): with Model() as model: w = Dirichlet("w", floatX(np.ones_like(self.pois_w)), shape=self.pois_w.shape) mu = Gamma("mu", 1.0, 1.0, shape=self.pois_w.size) Mixture("x_obs", w, Poisson.dist(mu), observed=self.pois_x) step = Metropolis() - trace = sample(5000, step, random_seed=self.random_seed, progressbar=False, chains=1) + trace = sample( + 5000, + step, + random_seed=self.random_seed, + progressbar=False, + chains=1, + return_inferencedata=False, + ) assert_allclose(np.sort(trace["w"].mean(axis=0)), np.sort(self.pois_w), rtol=0.1, atol=0.1) assert_allclose( @@ -223,7 +246,14 @@ def test_mixture_list_of_poissons(self): mu = Gamma("mu", 1.0, 1.0, shape=self.pois_w.size) Mixture("x_obs", w, [Poisson.dist(mu[0]), Poisson.dist(mu[1])], observed=self.pois_x) step = Metropolis() - trace = sample(5000, step, random_seed=self.random_seed, progressbar=False, chains=1) + trace = sample( + 5000, + step, + random_seed=self.random_seed, + progressbar=False, + chains=1, + return_inferencedata=False, + ) assert_allclose(np.sort(trace["w"].mean(axis=0)), np.sort(self.pois_w), rtol=0.1, atol=0.1) assert_allclose( @@ -249,21 +279,20 @@ def test_mixture_of_mvn(self): st.multivariate_normal.logpdf(obs, mu2, cov2), ) ).T - complogp = y.distribution._comp_logp(aesara.shared(obs)).eval() - assert_allclose(complogp, complogp_st) # check logp of mixture testpoint = model.compute_initial_point() mixlogp_st = logsumexp(np.log(testpoint["w"]) + complogp_st, axis=-1, keepdims=False) - assert_allclose(y.logp_elemwise(testpoint), mixlogp_st) + assert_allclose(model.compile_logp(y, sum=False)(testpoint)[0], mixlogp_st) # check logp of model priorlogp = st.dirichlet.logpdf( x=testpoint["w"], alpha=np.ones(2), ) - assert_allclose(model.logp(testpoint), mixlogp_st.sum() + priorlogp) + assert_allclose(model.compile_logp()(testpoint), mixlogp_st.sum() + priorlogp) + @pytest.mark.xfail(reason="Mixture from single component not refactored yet") def test_mixture_of_mixture(self): if aesara.config.floatX == "float32": rtol = 1e-4 @@ -365,10 +394,14 @@ def build_toy_dataset(N, K): packed_chol = [] chol = [] for i in range(K): - mu.append(pm.Normal("mu%i" % i, 0, 10, shape=D)) + mu.append(pm.Normal(f"mu{i}", 0, 10, shape=D)) packed_chol.append( pm.LKJCholeskyCov( - "chol_cov_%i" % i, eta=2, n=D, sd_dist=pm.HalfNormal.dist(2.5, size=D) + f"chol_cov_{i}", + eta=2, + n=D, + sd_dist=pm.HalfNormal.dist(2.5, size=D), + compute_corr=False, ) ) chol.append(pm.expand_packed_triangular(D, packed_chol[i], lower=True)) @@ -380,14 +413,80 @@ def build_toy_dataset(N, K): n_samples = 20 with model: - ppc = pm.sample_posterior_predictive(idata, n_samples) - prior = pm.sample_prior_predictive(samples=n_samples) + ppc = pm.sample_posterior_predictive(idata, n_samples, return_inferencedata=False) + prior = pm.sample_prior_predictive(samples=n_samples, return_inferencedata=False) assert ppc["x_obs"].shape == (n_samples,) + X.shape assert prior["x_obs"].shape == (n_samples,) + X.shape assert prior["mu0"].shape == (n_samples, D) assert prior["chol_cov_0"].shape == (n_samples, D * (D + 1) // 2) - + @pytest.mark.parametrize( + "weights, components, size, expected_shape", + [ + ( + pm.Dirichlet.dist(a=[100, 1]), + [pm.Normal.dist(-10, 0.001), pm.Normal.dist(10, 0.001)], + None, + (), + ), + ( + pm.Dirichlet.dist(a=[100, 1]), + [pm.Normal.dist(-10, 0.001), pm.Normal.dist(10, 0.001)], + (1,), + (1,), + ), + ( + pm.Dirichlet.dist(a=[100, 1]), + [pm.Normal.dist(-10, 0.001), pm.Normal.dist(10, 0.001)], + (2,), + (2,), + ), + ( + pm.Dirichlet.dist(a=[100, 1]), + [pm.Normal.dist(-10, 0.001), pm.Normal.dist(10, 0.001)], + (3, 2), + (3, 2), + ), + ( + pm.Dirichlet.dist(a=[100, 1]), + [pm.Normal.dist([-10, -5], 0.001), pm.Normal.dist([10, 5], 0.001)], + (3, 2), + (3, 2), + ), + ( + pm.Dirichlet.dist(a=[100, 1], size=3), + [pm.Normal.dist([-10, -5], 0.001), pm.Normal.dist([10, 5], 0.001)], + (3, 2), + (3, 2), + ), + ( + pm.Dirichlet.dist(a=[[100, 1], [1, 100], [100, 1]]), + [pm.Normal.dist([-10, -5], 0.001), pm.Normal.dist([10, 5], 0.001)], + (3, 2), + (3, 2), + ), + ], + ) + def test_mixture_size(self, weights, components, size, expected_shape): + # TODO: Test values come from expected components + mix = pm.Mixture.dist(weights, components, size=size) + mix_eval = mix.eval() + # TODO: Remove this + # print(mix_eval) + assert mix_eval.shape == expected_shape + + def test_mixture_choices_random(self): + # Test that mixture choices change over evaluations + with pm.Model() as m: + weights = [0.5, 0.5] + components = [pm.Normal.dist(-10, 0.01), pm.Normal.dist(10, 0.01)] + mix = pm.Mixture.dist(weights, components) + draws = pm.draw(mix, draws=10) + # Probability of coming from same component 10 times is 0.5**10 + assert np.unique(draws > 0).size == 2 + + +@pytest.mark.xfail(reason="NormalMixture not refactored yet") class TestMixtureVsLatent(SeededTest): def setup_method(self, *args, **kwargs): super().setup_method(*args, **kwargs) @@ -496,6 +595,7 @@ def logp_matches(self, mixture, latent_mix, z, npop, model): assert_allclose(mix_logp, latent_mix_logp, rtol=rtol) +@pytest.mark.xfail(reason="MixtureSameFamily not refactored yet") class TestMixtureSameFamily(SeededTest): @classmethod def setup_class(cls):