Michaelosthege/issue 3828 (#2)
* xarray test for fast posterior predictive sampling.

* Move Dataset translation to util.

The translation from an xarray Dataset to a list of points was previously open-coded into sample_posterior_predictive.  Pulled it out so it can be used in both sample_posterior_predictive and fast_sample_posterior_predictive.

* fast_sample_posterior_predictive support for xarray traces.
rpgoldman authored Mar 20, 2020
1 parent 796c9bb commit fc09a79
Showing 6 changed files with 1,292 additions and 18 deletions.
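
For orientation, here is a minimal usage sketch of what this commit enables: the posterior group of an ArviZ InferenceData object (an xarray Dataset) can now be passed directly to fast_sample_posterior_predictive, just as sample_posterior_predictive already accepted it. The model and variable names mirror the notebook added below; this sketch is an illustration, not part of the diff.

import arviz as az
import pymc3

with pymc3.Model() as pmodel:
    n = pymc3.Normal('n')
    trace = pymc3.sample()

with pmodel:
    d = pymc3.Deterministic('d', n * 4)

# .posterior is an xarray.Dataset holding the MCMC draws
idata = az.from_pymc3(trace)

with pmodel:
    # both functions now accept the Dataset directly
    pp_fast = pymc3.fast_sample_posterior_predictive(idata.posterior, var_names=['d'])
    pp = pymc3.sample_posterior_predictive(idata.posterior, var_names=['d'])
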
1,059 changes: 1,059 additions & 0 deletions docs/source/build.out

Large diffs are not rendered by default.

192 changes: 192 additions & 0 deletions docs/source/notebooks/Prior predictive bug.ipynb
@@ -0,0 +1,192 @@
{
"cells": [
{
"cell_type": "code",
"execution_count": 1,
"metadata": {},
"outputs": [],
"source": [
"import pymc3"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"Auto-assigning NUTS sampler...\n",
"Initializing NUTS using jitter+adapt_diag...\n",
"Multiprocess sampling (4 chains in 4 jobs)\n",
"NUTS: [n]\n"
]
},
{
"data": {
"text/html": [
"\n",
" <div>\n",
" <style>\n",
" /* Turns off some styling */\n",
" progress {\n",
" /* gets rid of default border in Firefox and Opera. */\n",
" border: none;\n",
" /* Needs to be in here for Safari polyfill so background images work as expected. */\n",
" background-size: auto;\n",
" }\n",
" .progress-bar-interrupted, .progress-bar-interrupted::-webkit-progress-bar {\n",
" background: #F44336;\n",
" }\n",
" </style>\n",
" <progress value='4000' class='' max='4000', style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
" 100.00% [4000/4000 00:00<00:00 Sampling 4 chains, 0 divergences]\n",
" </div>\n",
" "
],
"text/plain": [
"<IPython.core.display.HTML object>"
]
},
"metadata": {},
"output_type": "display_data"
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"The acceptance probability does not match the target. It is 0.8816556891941705, but should be close to 0.8. Try to increase the number of tuning steps.\n"
]
}
],
"source": [
"with pymc3.Model() as pmodel:\n",
" n = pymc3.Normal('n')\n",
" trace = pymc3.sample()\n",
"\n",
"with pmodel:\n",
" d = pymc3.Deterministic('d', n * 4)"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"CPU times: user 18.8 ms, sys: 3.42 ms, total: 22.2 ms\n",
"Wall time: 23.8 ms\n"
]
}
],
"source": [
"%%time\n",
"with pmodel:\n",
" pp = pymc3.fast_sample_posterior_predictive(\n",
" [trace[15]],\n",
" var_names=['d']\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {},
"outputs": [
{
"ename": "AttributeError",
"evalue": "'list' object has no attribute '_straces'",
"output_type": "error",
"traceback": [
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
"\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)",
"\u001b[0;32m<timed exec>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m\u001b[0m\n",
"\u001b[0;32m~/src/pymc3/pymc3/sampling.py\u001b[0m in \u001b[0;36msample_posterior_predictive\u001b[0;34m(trace, samples, model, vars, var_names, size, keep_size, random_seed, progressbar)\u001b[0m\n\u001b[1;32m 1539\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1540\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msamples\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1541\u001b[0;31m \u001b[0msamples\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0msum\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mlen\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mfor\u001b[0m \u001b[0mv\u001b[0m \u001b[0;32min\u001b[0m \u001b[0mtrace\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_straces\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mvalues\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1542\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1543\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0msamples\u001b[0m \u001b[0;34m<\u001b[0m \u001b[0mlen_trace\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mnchain\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
"\u001b[0;31mAttributeError\u001b[0m: 'list' object has no attribute '_straces'"
]
}
],
"source": [
"%%time\n",
"with pmodel:\n",
" pp = pymc3.sample_posterior_predictive(\n",
" [trace[15]],\n",
" var_names=['d']\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"{'n': 0.691903087470128}"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trace[15]"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"True"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"'MultiTrace' in dir(pymc3.backends.base)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
}
},
"nbformat": 4,
"nbformat_minor": 2
}
9 changes: 7 additions & 2 deletions pymc3/distributions/posterior_predictive.py
@@ -12,12 +12,14 @@
 import numpy as np
 import theano
 import theano.tensor as tt
+from xarray import Dataset

 from ..backends.base import MultiTrace #, TraceLike, TraceDict
 from .distribution import _DrawValuesContext, _DrawValuesContextBlocker, is_fast_drawable, _compile_theano_function, vectorized_ppc
 from ..model import Model, get_named_nodes_and_relations, ObservedRV, MultiObservedRV, modelcontext
 from ..exceptions import IncorrectArgumentsError
 from ..vartypes import theano_constant
+from ..util import dataset_to_point_dict
 # Failing tests:
 # test_mixture_random_shape::test_mixture_random_shape
 #
@@ -119,7 +121,7 @@ def __getitem__(self, item):



-def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.ndarray]]],
+def fast_sample_posterior_predictive(trace: Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]],
                                      samples: Optional[int]=None,
                                      model: Optional[Model]=None,
                                      var_names: Optional[List[str]]=None,
@@ -135,7 +137,7 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
     Parameters
     ----------
-    trace : MultiTrace or List of points
+    trace : MultiTrace, xarray.Dataset, or List of points (dictionary)
         Trace generated from MCMC sampling.
     samples : int, optional
         Number of posterior predictive samples to generate. Defaults to one posterior predictive
@@ -168,6 +170,9 @@ def fast_sample_posterior_predictive(trace: Union[MultiTrace, List[Dict[str, np.
     ### greater than the number of samples in the trace parameter, we sample repeatedly. This
     ### makes the shape issues just a little easier to deal with.

+    if isinstance(trace, Dataset):
+        trace = dataset_to_point_dict(trace)
+
     model = modelcontext(model)
     assert model is not None
     with model:
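
The two added lines above normalize the input up front, so the rest of fast_sample_posterior_predictive keeps handling only the cases it already knew about: a MultiTrace or a list of points. A minimal sketch of that dispatch in isolation; the helper name as_point_input is hypothetical and not part of the codebase.

from typing import Dict, List, Union

import numpy as np
from xarray import Dataset

from pymc3.backends.base import MultiTrace
from pymc3.util import dataset_to_point_dict

TraceInput = Union[MultiTrace, Dataset, List[Dict[str, np.ndarray]]]

def as_point_input(trace: TraceInput):
    # Hypothetical helper: reduce every accepted trace type to one of the two
    # forms the downstream sampling code already handles.
    if isinstance(trace, Dataset):
        return dataset_to_point_dict(trace)  # xarray posterior -> list of points
    return trace  # MultiTrace and point lists pass through unchanged

# a list of points passes through untouched
points = [{'n': np.array(0.5)}]
assert as_point_input(points) is points
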
17 changes: 2 additions & 15 deletions pymc3/sampling.py
@@ -54,6 +54,7 @@
     get_untransformed_name,
     is_transformed_name,
     get_default_varnames,
+    dataset_to_point_dict,
 )
 from .vartypes import discrete_types
 from .exceptions import IncorrectArgumentsError
@@ -1558,21 +1559,7 @@ def sample_posterior_predictive(
         posterior predictive samples.
     """
     if isinstance(trace, xarray.Dataset):
-        # grab posterior samples for each variable
-        _samples = {
-            vn : trace[vn].values
-            for vn in trace.keys()
-        }
-        # make dicts
-        points = []
-        for c in trace.chain:
-            for d in trace.draw:
-                points.append({
-                    vn : s[c, d]
-                    for vn, s in _samples.items()
-                })
-        # use the list of points
-        trace = points
+        trace = dataset_to_point_dict(trace)

     len_trace = len(trace)
     try:
9 changes: 9 additions & 0 deletions pymc3/tests/test_sampling.py
@@ -901,3 +901,12 @@ def test_sample_from_xarray_posterior(self, point_list_arg_bug_fixture):
                 idat.posterior,
                 var_names=['d']
             )
+
+    def test_sample_from_xarray_posterior_fast(self, point_list_arg_bug_fixture):
+        pmodel, trace = point_list_arg_bug_fixture
+        idat = az.from_pymc3(trace)
+        with pmodel:
+            pp = pm.fast_sample_posterior_predictive(
+                idat.posterior,
+                var_names=['d']
+            )
24 changes: 23 additions & 1 deletion pymc3/util.py
@@ -14,7 +14,11 @@

 import re
 import functools
-from numpy import asscalar
+from typing import List, Dict
+
+import xarray
+from numpy import asscalar, ndarray
+

 LATEX_ESCAPE_RE = re.compile(r'(%|_|\$|#|&)', re.MULTILINE)

@@ -179,3 +183,21 @@ def enhanced(*args, **kwargs):
         newwrapper = functools.partial(wrapper, *args, **kwargs)
         return newwrapper
     return enhanced
+
+def dataset_to_point_dict(ds: xarray.Dataset) -> List[Dict[str, ndarray]]:
+    # grab posterior samples for each variable
+    _samples = {
+        vn : ds[vn].values
+        for vn in ds.keys()
+    }
+    # make dicts
+    points = []
+    for c in ds.chain:
+        for d in ds.draw:
+            points.append({
+                vn : s[c, d]
+                for vn, s in _samples.items()
+            })
+    # use the list of points
+    ds = points
+    return ds
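
For reference, a minimal sketch of what dataset_to_point_dict returns, using a hand-built toy posterior (the variable name and shapes are illustrative; in practice the Dataset is typically arviz.InferenceData.posterior).

import numpy as np
import xarray as xr

from pymc3.util import dataset_to_point_dict

# toy "posterior": 2 chains x 3 draws of a scalar variable 'n', with the
# (chain, draw) coordinates the helper iterates over
ds = xr.Dataset(
    {'n': (('chain', 'draw'), np.arange(6.0).reshape(2, 3))},
    coords={'chain': [0, 1], 'draw': [0, 1, 2]},
)

points = dataset_to_point_dict(ds)
print(len(points))  # 6 -- one point per (chain, draw) pair
print(points[0])    # roughly {'n': array(0.)}
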
