Skip to content

Commit

Permalink
Deprecate sample_posterior_predictive_w (#6254)
Browse files Browse the repository at this point in the history
Closes #4807
  • Loading branch information
zaxtax authored Oct 30, 2022
1 parent bcffce2 commit 9105d74
Show file tree
Hide file tree
Showing 2 changed files with 4 additions and 186 deletions.
125 changes: 4 additions & 121 deletions pymc/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -2065,127 +2065,10 @@ def sample_posterior_predictive_w(
weighted models (default), or a dictionary with variable names as keys, and samples as
numpy arrays.
"""
raise NotImplementedError(f"sample_posterior_predictive_w has not yet been ported to PyMC 4.0.")

if isinstance(traces[0], InferenceData):
n_samples = [
trace.posterior.sizes["chain"] * trace.posterior.sizes["draw"] for trace in traces
]
traces = [dataset_to_point_list(trace.posterior) for trace in traces]
elif isinstance(traces[0], xarray.Dataset):
n_samples = [trace.sizes["chain"] * trace.sizes["draw"] for trace in traces]
traces = [dataset_to_point_list(trace) for trace in traces]
else:
n_samples = [len(i) * i.nchains for i in traces]

if models is None:
models = [modelcontext(models)] * len(traces)

if random_seed is not None:
(random_seed,) = _get_seeds_per_chain(random_seed, 1)

for model in models:
if model.potentials:
warnings.warn(
"The effect of Potentials on other parameters is ignored during posterior predictive sampling. "
"This is likely to lead to invalid or biased predictive samples.",
UserWarning,
stacklevel=2,
)
break

if weights is None:
weights = [1] * len(traces)

if len(traces) != len(weights):
raise ValueError("The number of traces and weights should be the same")

if len(models) != len(weights):
raise ValueError("The number of models and weights should be the same")

length_morv = len(models[0].observed_RVs)
if any(len(i.observed_RVs) != length_morv for i in models):
raise ValueError("The number of observed RVs should be the same for all models")

weights = np.asarray(weights)
p = weights / np.sum(weights)

min_tr = min(n_samples)

n = (min_tr * p).astype("int")
# ensure n sum up to min_tr
idx = np.argmax(n)
n[idx] = n[idx] + min_tr - np.sum(n)
trace = []
for i, j in enumerate(n):
tr = traces[i]
len_trace = len(tr)
try:
nchain = tr.nchains
except AttributeError:
nchain = 1

indices = np.random.randint(0, nchain * len_trace, j)
if nchain > 1:
chain_idx, point_idx = np.divmod(indices, len_trace)
for cidx, pidx in zip(chain_idx, point_idx):
trace.append(tr._straces[cidx].point(pidx))
else:
for idx in indices:
trace.append(tr[idx])

obs = [x for m in models for x in m.observed_RVs]
variables = np.repeat(obs, n)

lengths = list({np.atleast_1d(observed).shape for observed in obs})

size: List[Optional[Tuple[int, ...]]] = []
if len(lengths) == 1:
size = [None] * len(variables)
elif len(lengths) > 2:
raise ValueError("Observed variables could not be broadcast together")
else:
x = np.zeros(shape=lengths[0])
y = np.zeros(shape=lengths[1])
b = np.broadcast(x, y)
for var in variables:
# XXX: This needs to be refactored
shape = None # np.shape(np.atleast_1d(var.distribution.default()))
if shape != b.shape:
size.append(b.shape)
else:
size.append(None)
len_trace = len(trace)

if samples is None:
samples = len_trace

indices = np.random.randint(0, len_trace, samples)

if progressbar:
indices = progress_bar(indices, total=samples, display=progressbar)

try:
ppcl: Dict[str, list] = defaultdict(list)
for idx in indices:
param = trace[idx]
var = variables[idx]
# TODO: sample_posterior_predictive_w currently only works for models with
# one observed variable.
# XXX: This needs to be refactored
# ppc[var.name].append(draw_values([var], point=param, size=size[idx])[0])
raise NotImplementedError()

except KeyboardInterrupt:
pass
else:
ppcd = {k: np.asarray(v) for k, v in ppcl.items()}
if not return_inferencedata:
return ppcd
ikwargs: Dict[str, Any] = dict(model=models)
if idata_kwargs:
ikwargs.update(idata_kwargs)
return pm.to_inference_data(posterior_predictive=ppcd, **ikwargs)
raise FutureWarning(
"The function `sample_posterior_predictive_w` has been removed in PyMC 4.3.0. "
"Switch to `arviz.stats.weight_predictions`"
)


def sample_prior_predictive(
Expand Down
65 changes: 0 additions & 65 deletions pymc/tests/test_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -1177,71 +1177,6 @@ def test_logging_sampled_basic_rvs_posterior_mutable(self, mock_sample_results,
caplog.clear()


# The whole class is expected to fail: it is marked xfail and only a
# NotImplementedError counts as the expected failure (``raises=`` below).
@pytest.mark.xfail(
    reason="sample_posterior_predictive_w not refactored for v4", raises=NotImplementedError
)
class TestSamplePPCW(SeededTest):
    """Tests for ``pm.sample_posterior_predictive_w`` (weighted posterior predictive sampling)."""

    def test_sample_posterior_predictive_w(self):
        """Check output shapes and argument validation of ``sample_posterior_predictive_w``.

        Builds two models that share the same 50-point observed data set, samples
        short traces from each, and then checks:

        * the shape of the ``"y"`` predictive samples for trace inputs,
          InferenceData inputs and a bare ``posterior`` group input;
        * the ``ValueError`` messages raised for mismatched traces/weights,
          mismatched models/weights, and models with differing numbers of
          observed RVs.
        """
        data0 = np.random.normal(0, 1, size=50)
        # pm.sample is called with very few draws below; this is the warning
        # text the test expects each of those calls to emit.
        warning_msg = "The number of samples is too small to check convergence reliably"

        # Model with a scalar mean.
        with pm.Model() as model_0:
            mu = pm.Normal("mu", mu=0, sigma=1)
            y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
            with pytest.warns(UserWarning, match=warning_msg):
                trace_0 = pm.sample(10, tune=0, chains=2, return_inferencedata=False)
            idata_0 = pm.to_inference_data(trace_0, log_likelihood=False)

        # Model with one mean per observation (size=len(data0)).
        with pm.Model() as model_1:
            mu = pm.Normal("mu", mu=0, sigma=1, size=len(data0))
            y = pm.Normal("y", mu=mu, sigma=1, observed=data0)
            with pytest.warns(UserWarning, match=warning_msg):
                trace_1 = pm.sample(10, tune=0, chains=2, return_inferencedata=False)
            idata_1 = pm.to_inference_data(trace_1, log_likelihood=False)

        with pm.Model() as model_2:
            # Model with no observed RVs.
            mu = pm.Normal("mu", mu=0, sigma=1)
            with pytest.warns(UserWarning, match=warning_msg):
                trace_2 = pm.sample(10, tune=0, return_inferencedata=False)

        traces = [trace_0, trace_1]
        idatas = [idata_0, idata_1]
        models = [model_0, model_1]

        # 100 requested samples of the 50-element observed variable "y".
        ppc = pm.sample_posterior_predictive_w(traces, 100, models)
        assert ppc["y"].shape == (100, 50)

        # Same check with InferenceData objects instead of raw traces.
        ppc = pm.sample_posterior_predictive_w(idatas, 100, models)
        assert ppc["y"].shape == (100, 50)

        # No explicit models and samples=None: the model is taken from the
        # context. The expected 20 presumably corresponds to 2 chains x 10 draws.
        with model_0:
            ppc = pm.sample_posterior_predictive_w([idata_0.posterior], None)
            assert ppc["y"].shape == (20, 50)

        # One trace but two weights.
        with pytest.raises(ValueError, match="The number of traces and weights should be the same"):
            pm.sample_posterior_predictive_w([idata_0.posterior], 100, models, weights=[0.5, 0.5])

        # One trace (weights omitted) but two models.
        with pytest.raises(ValueError, match="The number of models and weights should be the same"):
            pm.sample_posterior_predictive_w([idata_0.posterior], 100, models)

        # model_0 has an observed RV, model_2 has none.
        with pytest.raises(
            ValueError, match="The number of observed RVs should be the same for all models"
        ):
            pm.sample_posterior_predictive_w([trace_0, trace_2], 100, [model_0, model_2])

    def test_potentials_warning(self):
        """Check that a model containing a Potential triggers the Potentials ``UserWarning``."""
        warning_msg = "The effect of Potentials on other parameters is ignored during"
        with pm.Model() as m:
            a = pm.Normal("a", 0, 1)
            p = pm.Potential("p", a + 1)
            obs = pm.Normal("obs", a, 1, observed=5)

        # Hand-built trace from a dict, so no sampling is needed for this check.
        trace = az_from_dict({"a": np.random.rand(10)})
        with pytest.warns(UserWarning, match=warning_msg):
            pm.sample_posterior_predictive_w(samples=5, traces=[trace, trace], models=[m, m])


def check_exec_nuts_init(method):
with pm.Model() as model:
pm.Normal("a", mu=0, sigma=1, size=2)
Expand Down

0 comments on commit 9105d74

Please sign in to comment.