
Misc updates #5738

Merged · 5 commits · May 4, 2022
22 changes: 15 additions & 7 deletions pymc/aesaraf.py
@@ -30,6 +30,7 @@
 import numpy as np
 import scipy.sparse as sps

+from aeppl.abstract import MeasurableVariable
 from aeppl.logprob import CheckParameterValue
 from aesara import config, scalar
 from aesara.compile.mode import Mode, get_mode
@@ -978,14 +979,21 @@ def compile_pymc(
     # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
     rng_updates = {}
     output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
-    for rv in (
-        node
-        for node in vars_between(inputs, output_to_list)
-        if node.owner and isinstance(node.owner.op, RandomVariable) and node not in inputs
+    for random_var in (
+        var
+        for var in vars_between(inputs, output_to_list)
+        if var.owner
+        and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
+        and var not in inputs
     ):
-        rng = rv.owner.inputs[0]
-        if not hasattr(rng, "default_update"):
-            rng_updates[rng] = rv.owner.outputs[0]
+        if isinstance(random_var.owner.op, RandomVariable):
+            rng = random_var.owner.inputs[0]
+            if not hasattr(rng, "default_update"):
+                rng_updates[rng] = random_var.owner.outputs[0]
+        else:
+            update_fn = getattr(random_var.owner.op, "update", None)
+            if update_fn is not None:
+                rng_updates.update(update_fn(random_var.owner))

     # If called inside a model context, see if check_bounds flag is set to False
     try:
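Note: a minimal sketch of the protocol this change enables, assuming a custom op (the class name below is hypothetical; `MarginalMixtureRV` later in this PR is the real first consumer). Any op registered as a `MeasurableVariable` can expose an `update` method mapping its hidden RNG input to the next RNG state output, and `compile_pymc` will collect it as a function update.

```python
from aesara.compile.builders import OpFromGraph


class MyMeasurableRV(OpFromGraph):
    """Hypothetical measurable op whose first input is a hidden RNG."""

    def update(self, node):
        # Map the hidden RNG input to its next-state output so that
        # compile_pymc registers it as a default update.
        return {node.inputs[0]: node.outputs[0]}
```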
6 changes: 3 additions & 3 deletions pymc/distributions/bound.py
@@ -23,7 +23,7 @@
 from pymc.distributions.continuous import BoundedContinuous, bounded_cont_transform
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import Continuous, Discrete
-from pymc.distributions.logprob import logp
+from pymc.distributions.logprob import ignore_logprob, logp
 from pymc.distributions.shape_utils import to_tuple
 from pymc.distributions.transforms import _default_transform
 from pymc.model import modelcontext
@@ -193,7 +193,7 @@ def __new__(
             raise ValueError("Given dims do not exist in model coordinates.")

         lower, upper, initval = cls._set_values(lower, upper, size, shape, initval)
-        dist.tag.ignore_logprob = True
+        dist = ignore_logprob(dist)

         if isinstance(dist.owner.op, Continuous):
             res = _ContinuousBounded(
@@ -228,7 +228,7 @@ def dist(

         cls._argument_checks(dist, **kwargs)
         lower, upper, initval = cls._set_values(lower, upper, size, shape, initval=None)
-        dist.tag.ignore_logprob = True
+        dist = ignore_logprob(dist)
         if isinstance(dist.owner.op, Continuous):
             res = _ContinuousBounded.dist(
                 [dist, lower, upper],
9 changes: 6 additions & 3 deletions pymc/distributions/censored.py
@@ -42,9 +42,12 @@ class Censored(SymbolicDistribution):

     Parameters
     ----------
-    dist: PyMC unnamed distribution
-        PyMC distribution created via the `.dist()` API, which will be censored. This
-        distribution must be univariate and have a logcdf method implemented.
+    dist: unnamed distribution
+        Univariate distribution created via the `.dist()` API, which will be censored.
+        This distribution must have a logcdf method implemented for sampling.
+
+        .. warning:: dist will be cloned, rendering it independent of the one passed as input.
+
     lower: float or None
         Lower (left) censoring point. If `None` the distribution will not be left censored
     upper: float or None
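Note: a minimal usage sketch of the documented API (values are illustrative):

```python
import pymc as pm

with pm.Model():
    # Unnamed base distribution; per the warning above, Censored clones it,
    # so this instance stays independent of the registered variable.
    base = pm.Normal.dist(mu=0.0, sigma=1.0)
    # Censor observations outside [-1, 1]
    x = pm.Censored("x", base, lower=-1.0, upper=1.0)
```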
24 changes: 23 additions & 1 deletion pymc/distributions/logprob.py
@@ -20,6 +20,7 @@
 import numpy as np

 from aeppl import factorized_joint_logprob
+from aeppl.abstract import assign_custom_measurable_outputs
 from aeppl.logprob import logcdf as logcdf_aeppl
 from aeppl.logprob import logprob as logp_aeppl
 from aeppl.transforms import TransformValuesOpt
@@ -221,7 +222,11 @@ def joint_logpt(

     transform_opt = TransformValuesOpt(transform_map)
     temp_logp_var_dict = factorized_joint_logprob(
-        tmp_rvs_to_values, extra_rewrites=transform_opt, use_jacobian=jacobian, **kwargs
+        tmp_rvs_to_values,
+        extra_rewrites=transform_opt,
+        use_jacobian=jacobian,
+        warn_missing_rvs=False,
+        **kwargs,
     )

     # Raise if there are unexpected RandomVariables in the logp graph

Review comment (author), on `warn_missing_rvs=False`: instead of warning, it will now fail, which is stricter behavior.
@@ -276,3 +281,20 @@ def logcdf(rv, value):

     value = at.as_tensor_variable(value, dtype=rv.dtype)
     return logcdf_aeppl(rv, value)
+
+
+def ignore_logprob(rv):
+    """Return a duplicated variable that is ignored when creating Aeppl logprob graphs
+
+    This is used in SymbolicDistributions that use other RVs as inputs but account
+    for their logp terms explicitly.
+
+    If the variable is already ignored, it is returned directly.
+    """
+    prefix = "Unmeasurable"
+    node = rv.owner
+    op_type = type(node.op)
+    if op_type.__name__.startswith(prefix):
+        return rv
+    new_node = assign_custom_measurable_outputs(node, type_prefix=prefix)
+    return new_node.outputs[node.outputs.index(rv)]
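Note: a small sketch of how this helper is used at the call sites updated elsewhere in this PR:

```python
import pymc as pm
from pymc.distributions.logprob import ignore_logprob

# An unnamed RV that feeds a SymbolicDistribution whose logp already
# accounts for it; the "Unmeasurable" clone keeps Aeppl from treating it
# as a measurable variable that needs its own logp term.
sd_dist = pm.Exponential.dist(1.0)
sd_dist = ignore_logprob(sd_dist)
```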
30 changes: 16 additions & 14 deletions pymc/distributions/mixture.py
@@ -21,7 +21,7 @@
 from aeppl.logprob import _logcdf, _logprob
 from aeppl.transforms import IntervalTransform
 from aesara.compile.builders import OpFromGraph
-from aesara.graph.basic import equal_computations
+from aesara.graph.basic import Node, equal_computations
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
@@ -30,7 +30,7 @@
 from pymc.distributions.continuous import Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
 from pymc.distributions.distribution import SymbolicDistribution, _moment, moment
-from pymc.distributions.logprob import logcdf, logp
+from pymc.distributions.logprob import ignore_logprob, logcdf, logp
 from pymc.distributions.shape_utils import to_tuple
 from pymc.distributions.transforms import _default_transform
 from pymc.util import check_dist_not_registered
@@ -44,6 +44,10 @@ class MarginalMixtureRV(OpFromGraph):

     default_output = 1

+    def update(self, node: Node):
+        # Update for the internal mix_indexes RV
+        return {node.inputs[0]: node.outputs[0]}
+

 MeasurableVariable.register(MarginalMixtureRV)
@@ -66,12 +70,15 @@ class Mixture(SymbolicDistribution):
     w : tensor_like of float
         w >= 0 and w <= 1
         the mixture weights
-    comp_dists : iterable of PyMC distributions or single batched distribution
-        Distributions should be created via the `.dist()` API. If single distribution is
-        passed, the last size dimension (not shape) determines the number of mixture
+    comp_dists : iterable of unnamed distributions or single batched distribution
+        Distributions should be created via the `.dist()` API. If a single distribution
+        is passed, the last size dimension (not shape) determines the number of mixture
         components (e.g. `pm.Poisson.dist(..., size=components)`)
         :math:`f_1, \ldots, f_n`
+
+        .. warning:: comp_dists will be cloned, rendering them independent of the ones passed as input.
+
     Examples
     --------
     .. code-block:: python
@@ -249,6 +256,10 @@ def rv_op(cls, weights, *components, size=None, rngs=None):

         assert weights_ndim_batch == 0

+        # Component RVs terms are accounted by the Mixture logprob, so they can be
+        # safely ignored by Aeppl
+        components = [ignore_logprob(component) for component in components]
+
         # Create a OpFromGraph that encapsulates the random generating process
         # Create dummy input variables with the same type as the ones provided
         weights_ = weights.type()
# Create the actual MarginalMixture variable
mix_out = mix_op(mix_indexes_rng, weights, *components)

# We need to set_default_updates ourselves, because the choices RV is hidden
# inside OpFromGraph and PyMC will never find it otherwise
mix_indexes_rng.default_update = mix_out.owner.outputs[0]

# Reference nodes to facilitate identification in other classmethods
mix_out.tag.weights = weights
mix_out.tag.components = components
mix_out.tag.choices_rng = mix_indexes_rng

# Component RVs terms are accounted by the Mixture logprob, so they can be
# safely ignore by Aeppl (this tag prevents UserWarning)
for component in components:
component.tag.ignore_logprob = True

return mix_out

@classmethod
Expand Down
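Note: a minimal sketch of the single-batched-component form described in the docstring above (values are illustrative):

```python
import numpy as np
import pymc as pm

with pm.Model():
    w = pm.Dirichlet("w", a=np.ones(2))
    # One batched unnamed distribution; the last size dimension (2)
    # sets the number of mixture components.
    components = pm.Poisson.dist(mu=[1.0, 100.0], size=2)
    mix = pm.Mixture("mix", w=w, comp_dists=components)
```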
10 changes: 6 additions & 4 deletions pymc/distributions/multivariate.py
@@ -57,6 +57,7 @@
     multigammaln,
 )
 from pymc.distributions.distribution import Continuous, Discrete, moment
+from pymc.distributions.logprob import ignore_logprob
 from pymc.distributions.shape_utils import (
     broadcast_dist_samples_to,
     rv_size_is_none,
@@ -1182,11 +1183,9 @@ def dist(cls, eta, n, sd_dist, **kwargs):

         # sd_dist is part of the generative graph, but should be completely ignored
         # by the logp graph, since the LKJ logp explicitly includes these terms.
-        # Setting sd_dist.tag.ignore_logprob to True, will prevent Aeppl warning about
-        # an unnacounted RandomVariable in the graph
         # TODO: Things could be simplified a bit if we managed to extract the
         # sd_dist prior components from the logp expression.
-        sd_dist.tag.ignore_logprob = True
+        sd_dist = ignore_logprob(sd_dist)

         return super().dist([n, eta, sd_dist], **kwargs)
@@ -1271,10 +1270,13 @@ class LKJCholeskyCov:
         larger values put more weight on matrices with few correlations.
     n: int
         Dimension of the covariance matrix (n > 1).
-    sd_dist: pm.Distribution
+    sd_dist: unnamed distribution
         A positive scalar or vector distribution for the standard deviations, created
         with the `.dist()` API. Should have `shape[-1]=n`. Scalar distributions will be
         automatically resized to ensure this.
+
+        .. warning:: sd_dist will be cloned, rendering it independent of the one passed as input.
+
     compute_corr: bool, default=True
         If `True`, returns three values: the Cholesky decomposition, the correlations
         and the standard deviations of the covariance matrix. Otherwise, only returns
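Note: a minimal usage sketch for the documented `sd_dist` parameter (values are illustrative):

```python
import pymc as pm

with pm.Model():
    # Unnamed positive distribution for the standard deviations; per the
    # warning above, it is cloned internally and stays independent.
    sd_dist = pm.Exponential.dist(1.0, shape=3)
    chol, corr, stds = pm.LKJCholeskyCov(
        "chol_cov", eta=2.0, n=3, sd_dist=sd_dist, compute_corr=True
    )
```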
2 changes: 0 additions & 2 deletions pymc/distributions/simulator.py
@@ -247,8 +247,6 @@ def logp(cls, value, sim_op, sim_inputs):
         # in which case this would not be needed. However, that would have to be
         # done for every sampler that may accomodate Simulators
         rng = aesara.shared(np.random.default_rng())
-        rng.tag.is_rng = True
-
         # Create a new simulatorRV with identical inputs as the original one
         sim_value = sim_op.make_node(rng, *sim_inputs[1:]).default_output()
         sim_value.name = "sim_value"
8 changes: 6 additions & 2 deletions pymc/distributions/timeseries.py
@@ -26,6 +26,7 @@
 from pymc.distributions import distribution, logprob, multivariate
 from pymc.distributions.continuous import Flat, Normal, get_tau_sigma
 from pymc.distributions.dist_math import check_parameters
+from pymc.distributions.logprob import ignore_logprob
 from pymc.distributions.shape_utils import rv_size_is_none, to_tuple
 from pymc.util import check_dist_not_registered
@@ -147,9 +148,12 @@ class GaussianRandomWalk(distribution.Continuous):
         innovation drift, defaults to 0.0
     sigma : tensor_like of float, optional
         sigma > 0, innovation standard deviation, defaults to 1.0
-    init : Univariate PyMC distribution
+    init : unnamed distribution
         Univariate distribution of the initial value, created with the `.dist()` API.
         Defaults to Normal with same `mu` and `sigma` as the GaussianRandomWalk
+
+        .. warning:: init will be cloned, rendering it independent of the one passed as input.
+
     steps : int
         Number of steps in Gaussian Random Walks (steps > 0).
     """
@@ -203,7 +207,7 @@ def dist(
             raise TypeError("init must be a univariate distribution variable")

         # Ignores logprob of init var because that's accounted for in the logp method
-        init.tag.ignore_logprob = True
+        init = ignore_logprob(init)

         return super().dist([mu, sigma, init, steps], size=size, **kwargs)
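Note: a minimal usage sketch for the documented `init` parameter (values are illustrative):

```python
import pymc as pm

with pm.Model():
    # Unnamed init distribution; per the warning above, it is cloned
    init = pm.Normal.dist(mu=0.0, sigma=10.0)
    grw = pm.GaussianRandomWalk("grw", mu=0.0, sigma=1.0, init=init, steps=100)
```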
10 changes: 3 additions & 7 deletions pymc/model.py
@@ -1363,13 +1363,9 @@ def make_obs_var(
             # size of the masked and unmasked array happened to coincide
             _, size, _, *inps = observed_rv_var.owner.inputs
             rng = self.model.next_rng()
-            observed_rv_var = observed_rv_var.owner.op(*inps, size=size, rng=rng)
-            # Add default_update to new rng
-            new_rng = observed_rv_var.owner.outputs[0]
-            observed_rv_var.update = (rng, new_rng)
-            rng.default_update = new_rng
-            observed_rv_var.name = f"{name}_observed"
-
+            observed_rv_var = observed_rv_var.owner.op(
+                *inps, size=size, rng=rng, name=f"{name}_observed"
+            )
             observed_rv_var.tag.observations = nonmissing_data

         self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)