Skip to content

Commit

Permalink
Add example test with more than 1 Simulator
Browse files Browse the repository at this point in the history
  • Loading branch information
ricardoV94 committed Jun 24, 2021
1 parent 3a6a0c1 commit 093e030
Showing 1 changed file with 58 additions and 0 deletions.
58 changes: 58 additions & 0 deletions pymc3/tests/test_smc.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,11 @@
import numpy as np
import pytest

from aesara.tensor.sort import SortOp

import pymc3 as pm

from pymc3.distributions.logp import _logp
from pymc3.tests.helpers import SeededTest


Expand Down Expand Up @@ -230,3 +233,58 @@ def test_simulator_metropolis_mcmc(self):

assert abs(self.data.mean() - trace["a"].mean()) < 0.05
assert abs(self.data.std() - trace["b"].mean()) < 0.05

def test_multiple_simulators(self):
def fn(rng, a, size):
return rng.normal(a, 0.1, size=size)

true_a = 2
true_b = -2

data1 = np.random.normal(true_a, 0.1, size=100)
data2 = np.random.normal(true_b, 0.1, size=100)

with pm.Model() as m:
a = pm.Normal("a", 0, 1)
b = pm.Normal("b", 1)

sim1 = pm.Simulator(
"sim1",
fn,
params=(a,),
sum_stat="sort",
distance="gaussian",
epsilon=1,
observed=data1,
)
sim2 = pm.Simulator(
"sim2",
fn,
params=(b,),
sum_stat="identity",
distance="laplace",
epsilon=0.5,
observed=data2,
)

trace = pm.sample_smc(chains=1)

assert abs(true_a - trace["a"].mean()) < 0.05
assert abs(true_b - trace["b"].mean()) < 0.05

# Check that the logps use the correct methods
sim1_val = m.rvs_to_values[sim1]
logp_sim1 = _logp(sim1.owner.op, sim1, {sim1: sim1_val})
logp_sim1_fn = aesara.function([sim1_val], logp_sim1)

sim2_val = m.rvs_to_values[sim2]
logp_sim2 = _logp(sim2.owner.op, sim2, {sim2: sim2_val})
logp_sim2_fn = aesara.function([sim2_val], logp_sim2)

assert any(
node for node in logp_sim1_fn.maker.fgraph.toposort() if isinstance(node.op, SortOp)
)

assert not any(
node for node in logp_sim2_fn.maker.fgraph.toposort() if isinstance(node.op, SortOp)
)

0 comments on commit 093e030

Please sign in to comment.