Skip to content

Commit

Permalink
Standardize Simulator logp
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jul 6, 2021
1 parent 7e01fd2 commit 729daea
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 14 deletions.
37 changes: 28 additions & 9 deletions pymc3/distributions/simulator.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,13 +34,31 @@
class SimulatorRV(RandomVariable):
"""A placeholder for Simulator RVs"""

name = "SimulatorRV"
_print_name = ("Simulator", "\\operatorname{Simulator}")
fn = None
epsilon = None
distance = None
sum_stat = None

@classmethod
def rng_fn(cls, *args, **kwargs):
if cls.fn is None:
raise ValueError(f"fn was not defined for {cls}")
return cls.fn(*args, **kwargs)

@classmethod
def _distance(cls, epsilon, value, sim_value):
if cls.distance is None:
raise ValueError(f"distance function was not defined for {cls}")
return cls.distance(epsilon, value, sim_value)

@classmethod
def _sum_stat(cls, value):
if cls.sum_stat is None:
raise ValueError(f"sum_stat function was not defined for {cls}")
return cls.sum_stat(value)


class Simulator(NoDistribution):
r"""
Expand Down Expand Up @@ -138,9 +156,9 @@ def __new__(
inplace=False,
# Specifc to Simulator
fn=fn,
# distance=distance,
# sum_stat=sum_stat,
# epsilon=epsilon,
distance=distance,
sum_stat=sum_stat,
epsilon=epsilon,
),
)()

Expand All @@ -150,24 +168,25 @@ def __new__(
@_logp.register(rv_type)
def logp(op, sim_rv, rvs_to_values, *sim_params, **kwargs):
value_var = rvs_to_values.get(sim_rv, sim_rv)
return cls.logp(
return Simulator.logp(
value_var,
sim_rv,
distance,
sum_stat,
epsilon,
)

cls.rv_op = sim_op
return super().__new__(cls, name, params, observed=observed, **kwargs)

@classmethod
def logp(cls, value, sim_rv, distance, sum_stat, epsilon):
def logp(cls, value, sim_rv):
# Create a new simulatorRV identically to the original one
sim_op = sim_rv.owner.op
sim_data = at.as_tensor_variable(sim_op.make_node(*sim_rv.owner.inputs))
sim_data.name = "sim_data"
return distance(epsilon, sum_stat(value), sum_stat(sim_data))
return sim_op._distance(
sim_op.epsilon,
sim_op._sum_stat(value),
sim_op._sum_stat(sim_data),
)


def identity(x):
Expand Down
20 changes: 15 additions & 5 deletions pymc3/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,8 @@
import numpy as np
import pytest

from aesara.graph.basic import ancestors
from aesara.tensor.random.op import RandomVariable
from aesara.tensor.sort import SortOp
from arviz.data.inference_data import InferenceData

Expand Down Expand Up @@ -336,11 +338,6 @@ def fn(rng, a, size):
observed=data2,
)

trace = pm.sample_smc(chains=1, return_inferencedata=False)

assert abs(true_a - trace["a"].mean()) < 0.05
assert abs(true_b - trace["b"].mean()) < 0.05

# Check that the logps use the correct methods
sim1_val = m.rvs_to_values[sim1]
logp_sim1 = _logp(sim1.owner.op, sim1, {sim1: sim1_val})
Expand All @@ -358,6 +355,19 @@ def fn(rng, a, size):
node for node in logp_sim2_fn.maker.fgraph.toposort() if isinstance(node.op, SortOp)
)

# Check that there are two RandomVariables in the graph
rvs = [
node
for node in ancestors([m.logpt])
if node.owner and isinstance(node.owner.op, RandomVariable)
]
assert len(rvs) == 2

with m:
trace = pm.sample_smc(chains=1, return_inferencedata=False)
assert abs(true_a - trace["a"].mean()) < 0.05
assert abs(true_b - trace["b"].mean()) < 0.05

def test_depracated_abc_args(self):
with self.SMABC_test:
with pytest.warns(
Expand Down

0 comments on commit 729daea

Please sign in to comment.