Skip to content

Commit

Permalink
Remove remaining uses of default_updates in codebase
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed May 4, 2022
1 parent ee73231 commit a3bd083
Show file tree
Hide file tree
Showing 4 changed files with 46 additions and 19 deletions.
22 changes: 15 additions & 7 deletions pymc/aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
import numpy as np
import scipy.sparse as sps

from aeppl.abstract import MeasurableVariable
from aeppl.logprob import CheckParameterValue
from aesara import config, scalar
from aesara.compile.mode import Mode, get_mode
Expand Down Expand Up @@ -978,14 +979,21 @@ def compile_pymc(
# TODO: This won't work for variables with InnerGraphs (Scan and OpFromGraph)
rng_updates = {}
output_to_list = outputs if isinstance(outputs, (list, tuple)) else [outputs]
for rv in (
node
for node in vars_between(inputs, output_to_list)
if node.owner and isinstance(node.owner.op, RandomVariable) and node not in inputs
for random_var in (
var
for var in vars_between(inputs, output_to_list)
if var.owner
and isinstance(var.owner.op, (RandomVariable, MeasurableVariable))
and var not in inputs
):
rng = rv.owner.inputs[0]
if not hasattr(rng, "default_update"):
rng_updates[rng] = rv.owner.outputs[0]
if isinstance(random_var.owner.op, RandomVariable):
rng = random_var.owner.inputs[0]
if not hasattr(rng, "default_update"):
rng_updates[rng] = random_var.owner.outputs[0]
else:
update_fn = getattr(random_var.owner.op, "update", None)
if update_fn is not None:
rng_updates.update(update_fn(random_var.owner))

# If called inside a model context, see if check_bounds flag is set to False
try:
Expand Down
10 changes: 5 additions & 5 deletions pymc/distributions/mixture.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@
from aeppl.logprob import _logcdf, _logprob
from aeppl.transforms import IntervalTransform
from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import equal_computations
from aesara.graph.basic import Node, equal_computations
from aesara.tensor import TensorVariable
from aesara.tensor.random.op import RandomVariable

Expand All @@ -44,6 +44,10 @@ class MarginalMixtureRV(OpFromGraph):

default_output = 1

def update(self, node: Node):
# Update for the internal mix_indexes RV
return {node.inputs[0]: node.outputs[0]}


MeasurableVariable.register(MarginalMixtureRV)

Expand Down Expand Up @@ -294,10 +298,6 @@ def rv_op(cls, weights, *components, size=None, rngs=None):
# Create the actual MarginalMixture variable
mix_out = mix_op(mix_indexes_rng, weights, *components)

# We need to set_default_updates ourselves, because the choices RV is hidden
# inside OpFromGraph and PyMC will never find it otherwise
mix_indexes_rng.default_update = mix_out.owner.outputs[0]

# Reference nodes to facilitate identification in other classmethods
mix_out.tag.weights = weights
mix_out.tag.components = components
Expand Down
10 changes: 3 additions & 7 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,13 +1363,9 @@ def make_obs_var(
# size of the masked and unmasked array happened to coincide
_, size, _, *inps = observed_rv_var.owner.inputs
rng = self.model.next_rng()
observed_rv_var = observed_rv_var.owner.op(*inps, size=size, rng=rng)
# Add default_update to new rng
new_rng = observed_rv_var.owner.outputs[0]
observed_rv_var.update = (rng, new_rng)
rng.default_update = new_rng
observed_rv_var.name = f"{name}_observed"

observed_rv_var = observed_rv_var.owner.op(
*inps, size=size, rng=rng, name=f"{name}_observed"
)
observed_rv_var.tag.observations = nonmissing_data

self.create_value_var(observed_rv_var, transform=None, value_var=nonmissing_data)
Expand Down
23 changes: 23 additions & 0 deletions pymc/tests/test_aesaraf.py
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,9 @@
import pytest
import scipy.sparse as sps

from aeppl.abstract import MeasurableVariable
from aeppl.logprob import ParameterValueError
from aesara.compile.builders import OpFromGraph
from aesara.graph.basic import Constant, Variable, ancestors, equal_computations
from aesara.tensor.random.basic import normal, uniform
from aesara.tensor.random.op import RandomVariable
Expand Down Expand Up @@ -681,3 +683,24 @@ def test_compile_pymc_updates_inputs(self):
assert len(fn_fgraph.apply_nodes) == max(rvs_in_graph, 1)
# Each RV adds a shared output for its rng
assert len(fn_fgraph.outputs) == 1 + rvs_in_graph

def test_compile_pymc_custom_update_op(self):
"""Test that custom MeasurableVariable Op updates are used by compile_pymc"""

class UnmeasurableOp(OpFromGraph):
def update(self, node):
return {node.inputs[0]: node.inputs[0] + 1}

dummy_inputs = [at.scalar(), at.scalar()]
dummy_outputs = [at.add(*dummy_inputs)]
dummy_x = UnmeasurableOp(dummy_inputs, dummy_outputs)(aesara.shared(1.0), 1.0)

# Check that there are no updates at first
fn = compile_pymc(inputs=[], outputs=dummy_x)
assert fn() == fn() == 2.0

# And they are enabled once the Op is registered as Measurable
MeasurableVariable.register(UnmeasurableOp)
fn = compile_pymc(inputs=[], outputs=dummy_x)
assert fn() == 2.0
assert fn() == 3.0

0 comments on commit a3bd083

Please sign in to comment.