Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add compute_deterministics helper #7238

Merged
merged 5 commits into from
Apr 6, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions .github/workflows/tests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ jobs:
- |
tests/distributions/test_censored.py
tests/distributions/test_simulator.py
tests/sampling/test_deterministic.py
tests/sampling/test_forward.py
tests/sampling/test_population.py
tests/stats/test_convergence.py
Expand Down
22 changes: 5 additions & 17 deletions docs/source/api/samplers.rst
Original file line number Diff line number Diff line change
Expand Up @@ -4,31 +4,19 @@ Samplers
This submodule contains functions for MCMC and forward sampling.


.. currentmodule:: pymc.sampling.forward
.. currentmodule:: pymc

.. autosummary::
:toctree: generated/

sample
sample_prior_predictive
sample_posterior_predictive
draw


.. currentmodule:: pymc.sampling.mcmc

.. autosummary::
:toctree: generated/

sample
compute_deterministics
init_nuts

.. currentmodule:: pymc.sampling.jax

.. autosummary::
:toctree: generated/

sample_blackjax_nuts
sample_numpyro_nuts
sampling.jax.sample_blackjax_nuts
sampling.jax.sample_numpyro_nuts


Step methods
Expand Down
80 changes: 77 additions & 3 deletions pymc/backends/arviz.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,26 +16,31 @@
import logging
import warnings

from collections.abc import Iterable, Mapping
from collections.abc import Iterable, Mapping, Sequence
from typing import (
TYPE_CHECKING,
Any,
Optional,
Union,
cast,
)

import numpy as np
import xarray

from arviz import InferenceData, concat, rcParams
from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires
from pytensor.graph.basic import Constant
from pytensor.tensor.sharedvar import SharedVariable
from rich.progress import Console, Progress
from rich.theme import Theme
from xarray import Dataset

import pymc

from pymc.model import Model, modelcontext
from pymc.pytensorf import extract_obs_data
from pymc.util import get_default_varnames
from pymc.pytensorf import PointFunc, extract_obs_data
from pymc.util import default_progress_theme, get_default_varnames

if TYPE_CHECKING:
from pymc.backends.base import MultiTrace
Expand Down Expand Up @@ -612,3 +617,72 @@ def predictions_to_inference_data(
# data and return that.
concat([new_idata, idata_orig], dim=None, copy=True, inplace=True)
return new_idata


def dataset_to_point_list(
    ds: xarray.Dataset | dict[str, xarray.DataArray], sample_dims: Sequence[str]
) -> tuple[list[dict[str, np.ndarray]], dict[str, Any]]:
    """Flatten the sample dimensions of ``ds`` into a list of per-sample points.

    Parameters
    ----------
    ds : xarray.Dataset or dict of str to xarray.DataArray
        Mapping from variable name to array. Every variable is transposed so
        that ``sample_dims`` come first.
    sample_dims : sequence of str
        Names of the dimensions that are collapsed into a single leading axis.

    Returns
    -------
    points : list of dict
        One ``{variable name: value}`` dict per position along the collapsed
        leading axis.
    stacked_dims : dict
        For each name in ``sample_dims``, the corresponding entry looked up on
        the first variable of ``ds``.

    Raises
    ------
    ValueError
        If any key of ``ds`` is not a ``str``.
    """
    names = cast(list[str], list(ds.keys()))
    # Points are keyed by variable name, so every dataset key has to be a str.
    for vn in names:
        if not isinstance(vn, str):
            raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.")

    n_batch_dims = len(sample_dims)
    # The first variable is taken as the reference for the sample dimensions.
    stacked_dims = {dim: ds[names[0]][dim] for dim in sample_dims}

    flat_values = {}
    for name, da in ds.items():
        # Move the sample dims to the front, then collapse them into one axis.
        moved = da.transpose(*sample_dims, ...)
        flat_values[name] = moved.values.reshape((-1, *moved.shape[n_batch_dims:]))

    n_points = np.prod([len(coord) for coord in stacked_dims.values()])
    points = [{name: flat_values[name][i, ...] for name in names} for i in range(n_points)]
    return cast(list[dict[str, np.ndarray]], points), stacked_dims


def apply_function_over_dataset(
    fn: PointFunc,
    dataset: Dataset,
    *,
    output_var_names: Sequence[str],
    coords,
    dims,
    sample_dims: Sequence[str] = ("chain", "draw"),
    progressbar: bool = True,
    progressbar_theme: Theme | None = default_progress_theme,
) -> Dataset:
    """Evaluate ``fn`` at every sample point of ``dataset`` and collect the outputs.

    Parameters
    ----------
    fn : PointFunc
        Compiled function called once per point; its outputs are paired
        positionally with ``output_var_names``.
    dataset : Dataset
        Input values. It is split into points along ``sample_dims`` with
        ``dataset_to_point_list``.
    output_var_names : sequence of str
        Names given to the outputs of ``fn``, in output order.
    coords, dims
        Forwarded unchanged to ``dict_to_dataset``.
    sample_dims : sequence of str, default ("chain", "draw")
        Dimensions of ``dataset`` that are iterated over.
    progressbar : bool, default True
        Whether the progress task is visible.
    progressbar_theme : Theme, optional
        Theme passed to the rich ``Console`` that renders the progress bar.

    Returns
    -------
    Dataset
        Dataset built from the collected outputs, with ``sample_dims`` as the
        leading dimensions.
    """
    posterior_pts, stacked_dims = dataset_to_point_list(dataset, sample_dims)

    n_pts = len(posterior_pts)
    out_dict = _DefaultTrace(n_pts)

    with Progress(console=Console(theme=progressbar_theme)) as progress:
        task = progress.add_task("Computing ...", total=n_pts, visible=progressbar)
        for idx, point in enumerate(posterior_pts):
            out = fn(point)
            fn.f.trust_input = True  # If we arrive here the dtypes are valid
            for var_name, val in zip(output_var_names, out):
                out_dict.insert(var_name, val, idx)

            progress.advance(task)

    out_trace = out_dict.trace_dict
    # Replace the flat leading axis of every output with the original sample dims.
    sample_shape = [len(coord) for coord in stacked_dims.values()]
    for key, val in out_trace.items():
        out_trace[key] = val.reshape((*sample_shape, *val.shape[1:]))

    return dict_to_dataset(
        out_trace,
        library=pymc,
        dims=dims,
        coords=coords,
        default_dims=list(sample_dims),
        skip_event_dims=True,
    )
31 changes: 27 additions & 4 deletions pymc/model/core.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,7 @@
import types
import warnings

from collections.abc import Callable, Iterable, Sequence
from collections.abc import Iterable, Sequence
from sys import modules
from typing import (
TYPE_CHECKING,
Expand All @@ -27,6 +27,7 @@
Optional,
TypeVar,
cast,
overload,
)

import numpy as np
Expand All @@ -35,7 +36,7 @@
import pytensor.tensor as pt
import scipy.sparse as sps

from pytensor.compile import DeepCopyOp, get_mode
from pytensor.compile import DeepCopyOp, Function, get_mode
from pytensor.compile.sharedvalue import SharedVariable
from pytensor.graph.basic import Constant, Variable, graph_inputs
from pytensor.scalar import Cast
Expand Down Expand Up @@ -1524,6 +1525,28 @@ def replace_rvs_by_values(
rvs_to_transforms=self.rvs_to_transforms,
)

# Typing-only overload: when ``point_fn`` is ``True`` (its default here),
# the declared return type is ``PointFunc``.
@overload
def compile_fn(
self,
outs: Variable | Sequence[Variable],
*,
inputs: Sequence[Variable] | None = None,
mode=None,
point_fn: Literal[True] = True,
**kwargs,
) -> PointFunc: ...

# Typing-only overload: when ``point_fn`` is explicitly ``False`` (no default
# in this signature), the declared return type is ``Function``.
@overload
def compile_fn(
self,
outs: Variable | Sequence[Variable],
*,
inputs: Sequence[Variable] | None = None,
mode=None,
point_fn: Literal[False],
**kwargs,
) -> Function: ...

def compile_fn(
self,
outs: Variable | Sequence[Variable],
Expand All @@ -1532,7 +1555,7 @@ def compile_fn(
mode=None,
point_fn: bool = True,
**kwargs,
) -> PointFunc | Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]:
) -> PointFunc | Function:
"""Compiles an PyTensor function

Parameters
Expand Down Expand Up @@ -2044,7 +2067,7 @@ def compile_fn(
point_fn: bool = True,
model: Model | None = None,
**kwargs,
) -> PointFunc | Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]:
) -> PointFunc | Function:
"""Compiles an PyTensor function

Parameters
Expand Down
1 change: 1 addition & 0 deletions pymc/sampling/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,5 +12,6 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from pymc.sampling.deterministic import compute_deterministics
from pymc.sampling.forward import *
from pymc.sampling.mcmc import *
114 changes: 114 additions & 0 deletions pymc/sampling/deterministic.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,114 @@
# Copyright 2024 The PyMC Developers
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from collections.abc import Sequence

import xarray

from xarray import Dataset

from pymc.backends.arviz import apply_function_over_dataset, coords_and_dims_for_inferencedata
from pymc.model.core import Model, modelcontext


def compute_deterministics(
    dataset: Dataset,
    *,
    var_names: Sequence[str] | None = None,
    model: Model | None = None,
    sample_dims: Sequence[str] = ("chain", "draw"),
    merge_dataset: bool = False,
    progressbar: bool = True,
    compile_kwargs: dict | None = None,
) -> Dataset:
    """Compute model deterministics given a dataset with values for model variables.

    Parameters
    ----------
    dataset : Dataset
        Dataset with values for model variables. Commonly InferenceData["posterior"].
        It must contain an entry for every free random variable of the model.
    var_names : sequence of str, optional
        List of names of deterministic variables to compute.
        If None, compute all deterministics in the model.
    model : Model, optional
        Model to use. If None, use context model.
    sample_dims : sequence of str, default ("chain", "draw")
        Sample (batch) dimensions of the dataset over which to compute the deterministics.
    merge_dataset : bool, default False
        Whether to extend the original dataset or return a new one.
    progressbar : bool, default True
        Whether to display a progress bar in the command line.
    compile_kwargs : dict, optional
        Additional arguments passed to `model.compile_fn`.

    Returns
    -------
    Dataset
        Dataset with values for the deterministics.

    Raises
    ------
    ValueError
        If any name in ``var_names`` does not refer to a model deterministic.

    Examples
    --------
    .. code:: python

        import pymc as pm

        with pm.Model(coords={"group": (0, 2, 4)}) as m:
            mu_raw = pm.Normal("mu_raw", 0, 1, dims="group")
            mu = pm.Deterministic("mu", mu_raw.cumsum(), dims="group")

            trace = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5)

        assert "mu" not in trace.posterior

        with m:
            trace.posterior = pm.compute_deterministics(trace.posterior, merge_dataset=True)

        assert "mu" in trace.posterior

    """
    model = modelcontext(model)

    if var_names is None:
        deterministics = model.deterministics
    else:
        deterministics = [model[var_name] for var_name in var_names]
        if not set(deterministics).issubset(set(model.deterministics)):
            raise ValueError("Not all var_names corresponded to model deterministics")

    fn = model.compile_fn(
        inputs=model.free_RVs,
        outs=deterministics,
        on_unused_input="ignore",
        **(compile_kwargs or {}),
    )

    coords, dims = coords_and_dims_for_inferencedata(model)

    new_dataset = apply_function_over_dataset(
        fn,
        dataset[[rv.name for rv in model.free_RVs]],
        # Name the outputs after the variables that were actually compiled;
        # using every model deterministic would mislabel a `var_names` subset.
        output_var_names=[det.name for det in deterministics],
        dims=dims,
        coords=coords,
        sample_dims=sample_dims,
        progressbar=progressbar,
    )

    if merge_dataset:
        new_dataset = xarray.merge([dataset, new_dataset], compat="override")

    return new_dataset
3 changes: 1 addition & 2 deletions pymc/sampling/forward.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,15 +49,14 @@

import pymc as pm

from pymc.backends.arviz import _DefaultTrace
from pymc.backends.arviz import _DefaultTrace, dataset_to_point_list
from pymc.backends.base import MultiTrace
from pymc.blocking import PointType
from pymc.model import Model, modelcontext
from pymc.pytensorf import compile_pymc
from pymc.util import (
RandomState,
_get_seeds_per_chain,
dataset_to_point_list,
default_progress_theme,
get_default_varnames,
point_wrapper,
Expand Down
Loading
Loading