Skip to content

Commit

Permalink
Remove Model auto_deterministics
Browse files Browse the repository at this point in the history
This property was initially added just to handle deterministics created by automatic imputation, in order to ensure the combined tensor of missing and observed components showed up in prior and posterior predictive sampling. At the same time, it allowed hiding the deterministic during mcmc sampling, saving memory use for large datasets. This last benefit is lost for the sake of simplicity. If a user is concerned, they can manually split the observed and missing components of a dataset when defining their model.
  • Loading branch information
ricardoV94 committed Nov 17, 2022
1 parent f1a94d2 commit 0121fd9
Show file tree
Hide file tree
Showing 4 changed files with 89 additions and 27 deletions.
12 changes: 4 additions & 8 deletions pymc/model.py
Original file line number Diff line number Diff line change
Expand Up @@ -560,7 +560,6 @@ def __init__(
self.rvs_to_initial_values = treedict(parent=self.parent.rvs_to_initial_values)
self.free_RVs = treelist(parent=self.parent.free_RVs)
self.observed_RVs = treelist(parent=self.parent.observed_RVs)
self.auto_deterministics = treelist(parent=self.parent.auto_deterministics)
self.deterministics = treelist(parent=self.parent.deterministics)
self.potentials = treelist(parent=self.parent.potentials)
self._coords = self.parent._coords
Expand All @@ -575,7 +574,6 @@ def __init__(
self.rvs_to_initial_values = treedict()
self.free_RVs = treelist()
self.observed_RVs = treelist()
self.auto_deterministics = treelist()
self.deterministics = treelist()
self.potentials = treelist()
self._coords = {}
Expand Down Expand Up @@ -1435,10 +1433,11 @@ def make_obs_var(
self.observed_RVs.append(observed_rv_var)

# Create deterministic that combines observed and missing
# Note: This can widely increase memory consumption during sampling for large datasets
rv_var = at.zeros(data.shape)
rv_var = at.set_subtensor(rv_var[mask.nonzero()], missing_rv_var)
rv_var = at.set_subtensor(rv_var[antimask_idx], observed_rv_var)
rv_var = Deterministic(name, rv_var, self, dims, auto=True)
rv_var = Deterministic(name, rv_var, self, dims)

else:
if sps.issparse(data):
Expand Down Expand Up @@ -1908,7 +1907,7 @@ def Point(*args, filter_model_vars=False, **kwargs) -> Dict[str, np.ndarray]:
}


def Deterministic(name, var, model=None, dims=None, auto=False):
def Deterministic(name, var, model=None, dims=None):
"""Create a named deterministic variable.
Deterministic nodes are only deterministic given all of their inputs, i.e.
Expand Down Expand Up @@ -1971,10 +1970,7 @@ def Deterministic(name, var, model=None, dims=None, auto=False):
"""
model = modelcontext(model)
var = var.copy(model.name_for(name))
if auto:
model.auto_deterministics.append(var)
else:
model.deterministics.append(var)
model.deterministics.append(var)
model.add_named_random_variable(var, dims)

from pymc.printing import str_for_potential_or_deterministic
Expand Down
29 changes: 23 additions & 6 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,14 @@
import xarray

from aesara import tensor as at
from aesara.graph.basic import Apply, Constant, Variable, general_toposort, walk
from aesara.graph.basic import (
Apply,
Constant,
Variable,
ancestors,
general_toposort,
walk,
)
from aesara.graph.fg import FunctionGraph
from aesara.tensor.random.var import (
RandomGeneratorSharedVariable,
Expand Down Expand Up @@ -324,6 +331,18 @@ def draw(
return [np.stack(v) for v in drawn_values]


def observed_dependent_deterministics(model: Model):
"""Find deterministics that depend directly on observed variables"""
deterministics = model.deterministics
observed_rvs = set(model.observed_RVs)
blockers = model.basic_RVs
return [
deterministic
for deterministic in deterministics
if observed_rvs & set(ancestors([deterministic], blockers=blockers))
]


def sample_prior_predictive(
samples: int = 500,
model: Optional[Model] = None,
Expand Down Expand Up @@ -371,10 +390,8 @@ def sample_prior_predictive(
)

if var_names is None:
prior_pred_vars = model.observed_RVs + model.auto_deterministics
prior_vars = (
get_default_varnames(model.unobserved_RVs, include_transformed=True) + model.potentials
)
prior_pred_vars = model.observed_RVs
prior_vars = get_default_varnames(model.unobserved_RVs, include_transformed=True)
vars_: Set[str] = {var.name for var in prior_vars + prior_pred_vars}
else:
vars_ = set(var_names)
Expand Down Expand Up @@ -571,7 +588,7 @@ def sample_posterior_predictive(
if var_names is not None:
vars_ = [model[x] for x in var_names]
else:
vars_ = model.observed_RVs + model.auto_deterministics
vars_ = model.observed_RVs + observed_dependent_deterministics(model)

indices = np.arange(samples)
if progressbar:
Expand Down
17 changes: 17 additions & 0 deletions pymc/tests/sampling/test_forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,7 @@
from pymc.sampling.forward import (
compile_forward_sampling_function,
get_vars_in_point_list,
observed_dependent_deterministics,
)
from pymc.tests.helpers import SeededTest, fast_unstable_sampling_mode

Expand Down Expand Up @@ -1621,3 +1622,19 @@ def test_get_vars_in_point_list():
trace = MultiTrace([strace])
vars_in_trace = get_vars_in_point_list(trace, modelB)
assert set(vars_in_trace) == {a}


def test_observed_dependent_deterministics():
with pm.Model() as m:
free = pm.Normal("free")
obs = pm.Normal("obs", observed=1)

det_free = pm.Deterministic("det_free", free + 1)
det_free2 = pm.Deterministic("det_free2", det_free + 1)

det_obs = pm.Deterministic("det_obs", obs + 1)
det_obs2 = pm.Deterministic("det_obs2", det_obs + 1)

det_mixed = pm.Deterministic("det_mixed", free + obs)

assert set(observed_dependent_deterministics(m)) == {det_obs, det_obs2, det_mixed}
58 changes: 45 additions & 13 deletions pymc/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -1195,22 +1195,30 @@ def test_missing_dual_observations(self):
trace = pm.sample(chains=1, tune=5, draws=50)

def test_interval_missing_observations(self):
rng = np.random.default_rng(1198)

with pm.Model() as model:
obs1 = np.ma.masked_values([1, 2, -1, 4, -1], value=-1)
obs2 = np.ma.masked_values([-1, -1, 6, -1, 8], value=-1)

rng = aesara.shared(np.random.RandomState(2323), borrow=True)

with pytest.warns(ImputationWarning):
theta1 = pm.Uniform("theta1", 0, 5, observed=obs1, rng=rng)
theta1 = pm.Uniform("theta1", 0, 5, observed=obs1)
with pytest.warns(ImputationWarning):
theta2 = pm.Normal("theta2", mu=theta1, observed=obs2, rng=rng)
theta2 = pm.Normal("theta2", mu=theta1, observed=obs2)

assert "theta1_observed" in model.named_vars
assert "theta1_missing_interval__" in model.named_vars
assert model.rvs_to_transforms[model.named_vars["theta1_observed"]] is None

prior_trace = pm.sample_prior_predictive(return_inferencedata=False)
prior_trace = pm.sample_prior_predictive(random_seed=rng, return_inferencedata=False)
assert set(prior_trace.keys()) == {
"theta1",
"theta1_observed",
"theta1_missing",
"theta2",
"theta2_observed",
"theta2_missing",
}

# Make sure the observed + missing combined deterministics have the
# same shape as the original observations vectors
Expand Down Expand Up @@ -1238,23 +1246,47 @@ def test_interval_missing_observations(self):
== 0.0
)

assert {"theta1", "theta2"} <= set(prior_trace.keys())

trace = pm.sample(
chains=1, draws=50, compute_convergence_checks=False, return_inferencedata=False
chains=1,
draws=50,
compute_convergence_checks=False,
return_inferencedata=False,
random_seed=rng,
)
assert set(trace.varnames) == {
"theta1",
"theta1_missing",
"theta1_missing_interval__",
"theta2",
"theta2_missing",
}

# Make sure that the missing values are newly generated samples and that
# the observed and deterministic match
assert np.all(0 < trace["theta1_missing"].mean(0))
assert np.all(0 < trace["theta2_missing"].mean(0))
assert "theta1" not in trace.varnames
assert "theta2" not in trace.varnames
assert np.isclose(np.mean(trace["theta1"][:, obs1.mask] - trace["theta1_missing"]), 0)
assert np.isclose(np.mean(trace["theta2"][:, obs2.mask] - trace["theta2_missing"]), 0)

# Make sure that the observed values are newly generated samples and that
# the observed and deterministic matche
pp_idata = pm.sample_posterior_predictive(trace)
# Make sure that the observed values are unchanged
assert np.allclose(np.var(trace["theta1"][:, ~obs1.mask], 0), 0.0)
assert np.allclose(np.var(trace["theta2"][:, ~obs2.mask], 0), 0.0)
np.testing.assert_array_equal(trace["theta1"][0][~obs1.mask], obs1[~obs1.mask])
np.testing.assert_array_equal(trace["theta2"][0][~obs2.mask], obs1[~obs2.mask])

pp_idata = pm.sample_posterior_predictive(trace, random_seed=rng)
pp_trace = pp_idata.posterior_predictive.stack(sample=["chain", "draw"]).transpose(
"sample", ...
)
assert set(pp_trace.keys()) == {
"theta1",
"theta1_observed",
"theta2",
"theta2_observed",
}

# Make sure that the observed values are newly generated samples and that
# the observed and deterministic match
assert np.all(np.var(pp_trace["theta1"], 0) > 0.0)
assert np.all(np.var(pp_trace["theta2"], 0) > 0.0)
assert np.isclose(
Expand Down

0 comments on commit 0121fd9

Please sign in to comment.