diff --git a/pymc/aesaraf.py b/pymc/aesaraf.py
index b263a20b9bf..e2de6d9fe27 100644
--- a/pymc/aesaraf.py
+++ b/pymc/aesaraf.py
@@ -30,6 +30,7 @@
 import numpy as np
 import scipy.sparse as sps
 
+from aeppl.abstract import MeasurableVariable
 from aeppl.logprob import CheckParameterValue
 from aesara import config, scalar
 from aesara.compile.mode import Mode, get_mode
@@ -978,14 +979,21 @@ def compile_pymc(
     # TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
     rng_updates = {}
     output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
-    for rv in (
-        node
-        for node in vars_between(inputs, output_to_list)
-        if node.owner and isinstance(node.owner.op, RandomVariable) and node not in inputs
+    for random_var in (
+        var
+        for var in vars_between(inputs, output_to_list)
+        if var.owner
+        and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
+        and var not in inputs
     ):
-        rng = rv.owner.inputs[0]
-        if not hasattr(rng, "default_update"):
-            rng_updates[rng] = rv.owner.outputs[0]
+        if isinstance(random_var.owner.op, RandomVariable):
+            rng = random_var.owner.inputs[0]
+            if not hasattr(rng, "default_update"):
+                rng_updates[rng] = random_var.owner.outputs[0]
+        else:
+            update_fn = getattr(random_var.owner.op, "update", None)
+            if update_fn is not None:
+                rng_updates.update(update_fn(random_var.owner))
 
     # If called inside a model context, see if check_bounds flag is set to False
     try:
diff --git a/pymc/distributions/mixture.py b/pymc/distributions/mixture.py
index 77f5e4dee76..b613f90bac7 100644
--- a/pymc/distributions/mixture.py
+++ b/pymc/distributions/mixture.py
@@ -21,7 +21,7 @@
 from aeppl.logprob import _logcdf, _logprob
 from aeppl.transforms import IntervalTransform
 from aesara.compile.builders import OpFromGraph
-from aesara.graph.basic import equal_computations
+from aesara.graph.basic import Node, equal_computations
 from aesara.tensor import TensorVariable
 from aesara.tensor.random.op import RandomVariable
 
@@ -44,6 +44,10 @@ class MarginalMixtureRV(OpFromGraph):
 
     default_output = 1
 
+    def update(self, node: Node):
+        # Update for the internal mix_indexes RV
+        return {node.inputs[0]: node.outputs[0]}
+
 
 MeasurableVariable.register(MarginalMixtureRV)
 
@@ -294,10 +298,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
         # Create the actual MarginalMixture variable
         mix_out = mix_op(mix_indexes_rng, weights, *components)
 
-        # We need to set_default_updates ourselves, because the choices RV is hidden
-        # inside OpFromGraph and PyMC will never find it otherwise
-        mix_indexes_rng.default_update = mix_out.owner.outputs[0]
-
         # Reference nodes to facilitate identification in other classmethods
         mix_out.tag.weights = weights
         mix_out.tag.components = components
diff --git a/pymc/model.py b/pymc/model.py
index 26b8ed66008..d13091be694 100644
--- a/pymc/model.py
+++ b/pymc/model.py
@@ -1363,13 +1363,9 @@ def make_obs_var(
             # size of the masked and unmasked array happened to coincide
             _, size, _, *inps = observed_rv_var.owner.inputs
             rng = self.model.next_rng()
-            observed_rv_var = observed_rv_var.owner.op(*inps, size=size, rng=rng)
-            # Add default_update to new rng
-            new_rng = observed_rv_var.owner.outputs[0]
-            observed_rv_var.update = (rng, new_rng)
-            rng.default_update = new_rng
-            observed_rv_var.name = f"{name}_observed"
-
+            observed_rv_var = observed_rv_var.owner.op(
+                *inps, size=size, rng=rng, name=f"{name}_observed"
+            )
             observed_rv_var.tag.observations = nonmissing_data
 
             self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
diff --git a/pymc/tests/test_aesaraf.py b/pymc/tests/test_aesaraf.py
index ef05068879c..709f21c6ef9 100644
--- a/pymc/tests/test_aesaraf.py
+++ b/pymc/tests/test_aesaraf.py
@@ -22,7 +22,9 @@
 import pytest
 import scipy.sparse as sps
 
+from aeppl.abstract import MeasurableVariable
 from aeppl.logprob import ParameterValueError
+from aesara.compile.builders import OpFromGraph
 from aesara.graph.basic import Constant, Variable, ancestors, equal_computations
 from aesara.tensor.random.basic import normal, uniform
 from aesara.tensor.random.op import RandomVariable
@@ -681,3 +683,24 @@ def test_compile_pymc_updates_inputs(self):
             assert len(fn_fgraph.apply_nodes) == max(rvs_in_graph, 1)
             # Each RV adds a shared output for its rng
             assert len(fn_fgraph.outputs) == 1 + rvs_in_graph
+
+    def test_compile_pymc_custom_update_op(self):
+        """Test that custom MeasurableVariable Op updates are used by compile_pymc"""
+
+        class UnmeasurableOp(OpFromGraph):
+            def update(self, node):
+                return {node.inputs[0]: node.inputs[0] + 1}
+
+        dummy_inputs = [at.scalar(), at.scalar()]
+        dummy_outputs = [at.add(*dummy_inputs)]
+        dummy_x = UnmeasurableOp(dummy_inputs, dummy_outputs)(aesara.shared(1.0), 1.0)
+
+        # Check that there are no updates at first
+        fn = compile_pymc(inputs=[], outputs=dummy_x)
+        assert fn() == fn() == 2.0
+
+        # And they are enabled once the Op is registered as Measurable
+        MeasurableVariable.register(UnmeasurableOp)
+        fn = compile_pymc(inputs=[], outputs=dummy_x)
+        assert fn() == 2.0
+        assert fn() == 3.0
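
Note (not part of the diff): the sketch below illustrates how the `update` protocol introduced above is consumed by `compile_pymc` once an Op is registered as a `MeasurableVariable`. It mirrors the new test; `CounterOp`, `counter`, and `inner_x` are hypothetical names used only for illustration, not identifiers from the patch.

```python
import aesara
import aesara.tensor as at

from aeppl.abstract import MeasurableVariable
from aesara.compile.builders import OpFromGraph

from pymc.aesaraf import compile_pymc


class CounterOp(OpFromGraph):
    # Hypothetical Op: compile_pymc collects this mapping as an Aesara update,
    # so the shared input is incremented on every function call.
    def update(self, node):
        return {node.inputs[0]: node.inputs[0] + 1}


# Registration is what makes compile_pymc look for the `update` method
MeasurableVariable.register(CounterOp)

counter = aesara.shared(0.0)
inner_x = at.dscalar()
out = CounterOp([inner_x], [inner_x + 1.0])(counter)

fn = compile_pymc(inputs=[], outputs=out)
print(fn(), fn(), fn())  # 1.0 2.0 3.0 -- the counter advances between calls
```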