diff --git a/.github/workflows/tests.yml b/.github/workflows/tests.yml index 55ccc3d2e02..155d24aa3f9 100644 --- a/.github/workflows/tests.yml +++ b/.github/workflows/tests.yml @@ -69,6 +69,7 @@ jobs: - | tests/distributions/test_censored.py tests/distributions/test_simulator.py + tests/sampling/test_deterministic.py tests/sampling/test_forward.py tests/sampling/test_population.py tests/stats/test_convergence.py diff --git a/docs/source/api/samplers.rst b/docs/source/api/samplers.rst index a62ffb285ed..5a7caa0c739 100644 --- a/docs/source/api/samplers.rst +++ b/docs/source/api/samplers.rst @@ -4,31 +4,19 @@ Samplers This submodule contains functions for MCMC and forward sampling. -.. currentmodule:: pymc.sampling.forward +.. currentmodule:: pymc .. autosummary:: :toctree: generated/ + sample sample_prior_predictive sample_posterior_predictive draw - - -.. currentmodule:: pymc.sampling.mcmc - -.. autosummary:: - :toctree: generated/ - - sample + compute_deterministics init_nuts - -.. currentmodule:: pymc.sampling.jax - -.. autosummary:: - :toctree: generated/ - - sample_blackjax_nuts - sample_numpyro_nuts + sampling.jax.sample_blackjax_nuts + sampling.jax.sample_numpyro_nuts Step methods diff --git a/pymc/backends/arviz.py b/pymc/backends/arviz.py index 4f32fe65cf0..3af240725fd 100644 --- a/pymc/backends/arviz.py +++ b/pymc/backends/arviz.py @@ -16,26 +16,31 @@ import logging import warnings -from collections.abc import Iterable, Mapping +from collections.abc import Iterable, Mapping, Sequence from typing import ( TYPE_CHECKING, Any, Optional, Union, + cast, ) import numpy as np +import xarray from arviz import InferenceData, concat, rcParams from arviz.data.base import CoordSpec, DimSpec, dict_to_dataset, requires from pytensor.graph.basic import Constant from pytensor.tensor.sharedvar import SharedVariable +from rich.progress import Console, Progress +from rich.theme import Theme +from xarray import Dataset import pymc from pymc.model import Model, modelcontext -from pymc.pytensorf import extract_obs_data -from pymc.util import get_default_varnames +from pymc.pytensorf import PointFunc, extract_obs_data +from pymc.util import default_progress_theme, get_default_varnames if TYPE_CHECKING: from pymc.backends.base import MultiTrace @@ -612,3 +617,72 @@ def predictions_to_inference_data( # data and return that. concat([new_idata, idata_orig], dim=None, copy=True, inplace=True) return new_idata + + +def dataset_to_point_list( + ds: xarray.Dataset | dict[str, xarray.DataArray], sample_dims: Sequence[str] +) -> tuple[list[dict[str, np.ndarray]], dict[str, Any]]: + # All keys of the dataset must be a str + var_names = cast(list[str], list(ds.keys())) + for vn in var_names: + if not isinstance(vn, str): + raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.") + num_sample_dims = len(sample_dims) + stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims} + transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()} + stacked_dict = { + vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) + for vn, da in transposed_dict.items() + } + points = [ + {vn: stacked_dict[vn][i, ...] 
for vn in var_names} + for i in range(np.prod([len(coords) for coords in stacked_dims.values()])) + ] + # use the list of points + return cast(list[dict[str, np.ndarray]], points), stacked_dims + + +def apply_function_over_dataset( + fn: PointFunc, + dataset: Dataset, + *, + output_var_names: Sequence[str], + coords, + dims, + sample_dims: Sequence[str] = ("chain", "draw"), + progressbar: bool = True, + progressbar_theme: Theme | None = default_progress_theme, +) -> Dataset: + posterior_pts, stacked_dims = dataset_to_point_list(dataset, sample_dims) + + n_pts = len(posterior_pts) + out_dict = _DefaultTrace(n_pts) + indices = range(n_pts) + + with Progress(console=Console(theme=progressbar_theme)) as progress: + task = progress.add_task("Computing ...", total=n_pts, visible=progressbar) + for idx in indices: + out = fn(posterior_pts[idx]) + fn.f.trust_input = True # If we arrive here the dtypes are valid + for var_name, val in zip(output_var_names, out): + out_dict.insert(var_name, val, idx) + + progress.advance(task) + + out_trace = out_dict.trace_dict + for key, val in out_trace.items(): + out_trace[key] = val.reshape( + ( + *[len(coord) for coord in stacked_dims.values()], + *val.shape[1:], + ) + ) + + return dict_to_dataset( + out_trace, + library=pymc, + dims=dims, + coords=coords, + default_dims=list(sample_dims), + skip_event_dims=True, + ) diff --git a/pymc/model/core.py b/pymc/model/core.py index cac340f7f42..eee8f2a904e 100644 --- a/pymc/model/core.py +++ b/pymc/model/core.py @@ -18,7 +18,7 @@ import types import warnings -from collections.abc import Callable, Iterable, Sequence +from collections.abc import Iterable, Sequence from sys import modules from typing import ( TYPE_CHECKING, @@ -27,6 +27,7 @@ Optional, TypeVar, cast, + overload, ) import numpy as np @@ -35,7 +36,7 @@ import pytensor.tensor as pt import scipy.sparse as sps -from pytensor.compile import DeepCopyOp, get_mode +from pytensor.compile import DeepCopyOp, Function, get_mode from pytensor.compile.sharedvalue import SharedVariable from pytensor.graph.basic import Constant, Variable, graph_inputs from pytensor.scalar import Cast @@ -1524,6 +1525,28 @@ def replace_rvs_by_values( rvs_to_transforms=self.rvs_to_transforms, ) + @overload + def compile_fn( + self, + outs: Variable | Sequence[Variable], + *, + inputs: Sequence[Variable] | None = None, + mode=None, + point_fn: Literal[True] = True, + **kwargs, + ) -> PointFunc: ... + + @overload + def compile_fn( + self, + outs: Variable | Sequence[Variable], + *, + inputs: Sequence[Variable] | None = None, + mode=None, + point_fn: Literal[False], + **kwargs, + ) -> Function: ... + def compile_fn( self, outs: Variable | Sequence[Variable], @@ -1532,7 +1555,7 @@ mode=None, point_fn: bool = True, **kwargs, - ) -> PointFunc | Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]: + ) -> PointFunc | Function: """Compiles an PyTensor function Parameters @@ -2044,7 +2067,7 @@ def compile_fn( point_fn: bool = True, model: Model | None = None, **kwargs, -) -> PointFunc | Callable[[Sequence[np.ndarray]], Sequence[np.ndarray]]: +) -> PointFunc | Function: """Compiles an PyTensor function Parameters diff --git a/pymc/sampling/__init__.py b/pymc/sampling/__init__.py index 8e854b7f5fd..547250cd58b 100644 --- a/pymc/sampling/__init__.py +++ b/pymc/sampling/__init__.py @@ -12,5 +12,6 @@ # See the License for the specific language governing permissions and # limitations under the License. 
+from pymc.sampling.deterministic import compute_deterministics from pymc.sampling.forward import * from pymc.sampling.mcmc import * diff --git a/pymc/sampling/deterministic.py b/pymc/sampling/deterministic.py new file mode 100644 index 00000000000..b300d5ee976 --- /dev/null +++ b/pymc/sampling/deterministic.py @@ -0,0 +1,112 @@ +# Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +from collections.abc import Sequence + +import xarray + +from xarray import Dataset + +from pymc.backends.arviz import apply_function_over_dataset, coords_and_dims_for_inferencedata +from pymc.model.core import Model, modelcontext + + +def compute_deterministics( + dataset: Dataset, + *, + var_names: Sequence[str] | None = None, + model: Model | None = None, + sample_dims: Sequence[str] = ("chain", "draw"), + merge_dataset: bool = False, + progressbar: bool = True, + compile_kwargs: dict | None = None, +) -> Dataset: + """Compute model deterministics given a dataset with values for model variables. + + Parameters + ---------- + dataset : Dataset + Dataset with values for model variables. Commonly InferenceData["posterior"]. + var_names : sequence of str, optional + List of names of deterministic variables to compute. + If None, compute all deterministics in the model. + model : Model, optional + Model to use. If None, use context model. + sample_dims : sequence of str, default ("chain", "draw") + Sample (batch) dimensions of the dataset over which to compute the deterministics. + merge_dataset : bool, default False + Whether to extend the original dataset or return a new one. + progressbar : bool, default True + Whether to display a progress bar in the command line. + compile_kwargs : dict, optional + Additional arguments passed to `model.compile_fn`. + + Returns + ------- + Dataset + Dataset with values for the deterministics. + + + Examples + -------- + .. 
code:: python + + import pymc as pm + + with pm.Model(coords={"group": (0, 2, 4)}) as m: + mu_raw = pm.Normal("mu_raw", 0, 1, dims="group") + mu = pm.Deterministic("mu", mu_raw.cumsum(), dims="group") + + trace = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5) + + assert "mu" not in trace.posterior + + with m: + trace.posterior = pm.compute_deterministics(trace.posterior, merge_dataset=True) + + assert "mu" in trace.posterior + + + """ + model = modelcontext(model) + + if var_names is None: + deterministics = model.deterministics + else: + deterministics = [model[var_name] for var_name in var_names] + if not set(deterministics).issubset(set(model.deterministics)): + raise ValueError("Not all var_names corresponded to model deterministics") + + fn = model.compile_fn( + inputs=model.free_RVs, + outs=deterministics, + on_unused_input="ignore", + **(compile_kwargs or {}), + ) + + coords, dims = coords_and_dims_for_inferencedata(model) + + new_dataset = apply_function_over_dataset( + fn, + dataset[[rv.name for rv in model.free_RVs]], + output_var_names=[det.name for det in deterministics], + dims=dims, + coords=coords, + sample_dims=sample_dims, + progressbar=progressbar, + ) + + if merge_dataset: + new_dataset = xarray.merge([dataset, new_dataset], compat="override") + + return new_dataset diff --git a/pymc/sampling/forward.py b/pymc/sampling/forward.py index 814ed8de926..fe0f2085bb0 100644 --- a/pymc/sampling/forward.py +++ b/pymc/sampling/forward.py @@ -49,7 +49,7 @@ import pymc as pm -from pymc.backends.arviz import _DefaultTrace +from pymc.backends.arviz import _DefaultTrace, dataset_to_point_list from pymc.backends.base import MultiTrace from pymc.blocking import PointType from pymc.model import Model, modelcontext @@ -57,7 +57,6 @@ from pymc.util import ( RandomState, _get_seeds_per_chain, - dataset_to_point_list, default_progress_theme, get_default_varnames, point_wrapper, diff --git a/pymc/stats/log_density.py b/pymc/stats/log_density.py index b4a6ca742bf..5b6406d02b1 100644 --- a/pymc/stats/log_density.py +++ b/pymc/stats/log_density.py @@ -12,18 +12,16 @@ # See the License for the specific language governing permissions and # limitations under the License. 
from collections.abc import Sequence -from typing import cast +from typing import Literal -from arviz import InferenceData, dict_to_dataset -from rich.console import Console -from rich.progress import Progress +from arviz import InferenceData +from xarray import Dataset -import pymc - -from pymc.backends.arviz import _DefaultTrace, coords_and_dims_for_inferencedata +from pymc.backends.arviz import ( + apply_function_over_dataset, + coords_and_dims_for_inferencedata, +) from pymc.model import Model, modelcontext -from pymc.pytensorf import PointFunc -from pymc.util import dataset_to_point_list, default_progress_theme __all__ = ("compute_log_likelihood", "compute_log_prior") @@ -113,10 +111,10 @@ def compute_log_density( var_names: Sequence[str] | None = None, extend_inferencedata: bool = True, model: Model | None = None, - kind="likelihood", + kind: Literal["likelihood", "prior"] = "likelihood", sample_dims: Sequence[str] = ("chain", "draw"), progressbar=True, -): +) -> InferenceData | Dataset: """ Compute elemwise log_likelihood or log_prior of model given InferenceData with posterior group """ @@ -159,40 +157,20 @@ def compute_log_density( outs=model.logp(vars=vars, sum=False), on_unused_input="ignore", ) - elemwise_logdens_fn = cast(PointFunc, elemwise_logdens_fn) finally: model.rvs_to_values = original_rvs_to_values model.rvs_to_transforms = original_rvs_to_transforms - # Ignore Deterministics - posterior_values = posterior[[rv.name for rv in model.free_RVs]] - posterior_pts, stacked_dims = dataset_to_point_list(posterior_values, sample_dims) - - n_pts = len(posterior_pts) - logdens_dict = _DefaultTrace(n_pts) - - with Progress(console=Console(theme=default_progress_theme)) as progress: - task = progress.add_task("Computing log density...", total=n_pts, visible=progressbar) - for idx in range(n_pts): - logdenss_pts = elemwise_logdens_fn(posterior_pts[idx]) - for rv_name, rv_logdens in zip(var_names, logdenss_pts): - logdens_dict.insert(rv_name, rv_logdens, idx) - progress.update(task, advance=1) - - logdens_trace = logdens_dict.trace_dict - for key, array in logdens_trace.items(): - logdens_trace[key] = array.reshape( - (*[len(coord) for coord in stacked_dims.values()], *array.shape[1:]) - ) - coords, dims = coords_and_dims_for_inferencedata(model) - logdens_dataset = dict_to_dataset( - logdens_trace, - library=pymc, + + logdens_dataset = apply_function_over_dataset( + elemwise_logdens_fn, + posterior[[rv.name for rv in model.free_RVs]], + output_var_names=var_names, + sample_dims=sample_dims, dims=dims, coords=coords, - default_dims=list(sample_dims), - skip_event_dims=True, + progressbar=progressbar, ) if extend_inferencedata: diff --git a/pymc/util.py b/pymc/util.py index a3f45e889c9..ccf97c89a31 100644 --- a/pymc/util.py +++ b/pymc/util.py @@ -16,7 +16,7 @@ import warnings from collections.abc import Sequence -from typing import Any, NewType, cast +from typing import NewType, cast import arviz import cloudpickle @@ -31,6 +31,20 @@ from pymc.exceptions import BlockModelAccessError + +def __getattr__(name): + if name == "dataset_to_point_list": + warnings.warn( + f"{name} has been moved to backends.arviz. 
Importing from util will fail in a future release.", + FutureWarning, + ) + from pymc.backends.arviz import dataset_to_point_list + + return dataset_to_point_list + + raise AttributeError(f"module {__name__} has no attribute {name}") + + VarName = NewType("VarName", str) default_progress_theme = Theme( @@ -247,29 +261,6 @@ def enhanced(*args, **kwargs): return enhanced -def dataset_to_point_list( - ds: xarray.Dataset | dict[str, xarray.DataArray], sample_dims: Sequence[str] -) -> tuple[list[dict[str, np.ndarray]], dict[str, Any]]: - # All keys of the dataset must be a str - var_names = cast(list[str], list(ds.keys())) - for vn in var_names: - if not isinstance(vn, str): - raise ValueError(f"Variable names must be str, but dataset key {vn} is a {type(vn)}.") - num_sample_dims = len(sample_dims) - stacked_dims = {dim_name: ds[var_names[0]][dim_name] for dim_name in sample_dims} - transposed_dict = {vn: da.transpose(*sample_dims, ...) for vn, da in ds.items()} - stacked_dict = { - vn: da.values.reshape((-1, *da.shape[num_sample_dims:])) - for vn, da in transposed_dict.items() - } - points = [ - {vn: stacked_dict[vn][i, ...] for vn in var_names} - for i in range(np.prod([len(coords) for coords in stacked_dims.values()])) - ] - # use the list of points - return cast(list[dict[str, np.ndarray]], points), stacked_dims - - def drop_warning_stat(idata: arviz.InferenceData) -> arviz.InferenceData: """Returns a new ``InferenceData`` object with the "warning" stat removed from sample stats groups. diff --git a/tests/backends/test_arviz.py b/tests/backends/test_arviz.py index 55eb00cc1da..2fa85e12daa 100644 --- a/tests/backends/test_arviz.py +++ b/tests/backends/test_arviz.py @@ -16,6 +16,7 @@ import numpy as np import pytensor.tensor as pt import pytest +import xarray from arviz import InferenceData from arviz.tests.helpers import check_multiple_attrs @@ -26,6 +27,7 @@ from pymc.backends.arviz import ( InferenceDataConverter, + dataset_to_point_list, predictions_to_inference_data, to_inference_data, ) @@ -776,3 +778,34 @@ def test_save_warmup_issue_1208_after_3_9(self): assert not fails assert idata.posterior.sizes["chain"] == 2 assert idata.posterior.sizes["draw"] == 30 + + +class TestDatasetToPointList: + @pytest.mark.parametrize("input_type", ("dict", "Dataset")) + def test_dataset_to_point_list(self, input_type): + if input_type == "dict": + ds = {} + elif input_type == "Dataset": + ds = xarray.Dataset() + ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw")) + pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"]) + assert isinstance(pl, list) + assert len(pl) == 6 + assert isinstance(pl[0], dict) + assert isinstance(pl[0]["A"], np.ndarray) + + def test_transposed_dataset_to_point_list(self): + ds = xarray.Dataset() + ds["A"] = xarray.DataArray([[[1, 2, 3], [2, 3, 4]]] * 5, dims=("team", "draw", "chain")) + pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"]) + assert isinstance(pl, list) + assert len(pl) == 6 + assert isinstance(pl[0], dict) + assert isinstance(pl[0]["A"], np.ndarray) + + def test_dataset_to_point_list_str_key(self): + # Check that non-str keys are caught + ds = xarray.Dataset() + ds[3] = xarray.DataArray([1, 2, 3]) + with pytest.raises(ValueError, match="must be str"): + dataset_to_point_list(ds, sample_dims=["chain", "draw"]) diff --git a/tests/sampling/test_deterministic.py b/tests/sampling/test_deterministic.py new file mode 100644 index 00000000000..f693a788c57 --- /dev/null +++ b/tests/sampling/test_deterministic.py @@ -0,0 +1,77 @@ +# 
Copyright 2024 The PyMC Developers +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +import numpy as np +import pytest + +from numpy.testing import assert_allclose + +from pymc.distributions import Normal +from pymc.model.core import Deterministic, Model +from pymc.sampling.deterministic import compute_deterministics +from pymc.sampling.forward import sample_prior_predictive + +# Turn all warnings into errors for this module +pytestmark = pytest.mark.filterwarnings("error") + + +def test_compute_deterministics(): + with Model(coords={"group": (0, 2, 4)}) as m: + mu_raw = Normal("mu_raw", 0, 1, dims="group") + mu = Deterministic("mu", mu_raw.cumsum(), dims="group") + + sigma_raw = Normal("sigma_raw", 0, 1) + sigma = Deterministic("sigma", sigma_raw.exp()) + + dataset = sample_prior_predictive( + samples=5, model=m, var_names=["mu_raw", "sigma_raw"], random_seed=22 + ).prior + + # Test default + with m: + all_dets = compute_deterministics(dataset) + assert set(all_dets.data_vars.variables) == {"mu", "sigma"} + assert all_dets["mu"].dims == ("chain", "draw", "group") + assert all_dets["sigma"].dims == ("chain", "draw") + assert_allclose(all_dets["mu"], dataset["mu_raw"].cumsum("group")) + assert_allclose(all_dets["sigma"], np.exp(dataset["sigma_raw"])) + + # Test custom arguments + extended_with_mu = compute_deterministics( + dataset, + var_names=["mu"], + merge_dataset=True, + model=m, + compile_kwargs={"mode": "FAST_COMPILE"}, + progressbar=False, + ) + assert set(extended_with_mu.data_vars.variables) == {"mu_raw", "sigma_raw", "mu"} + assert extended_with_mu["mu"].dims == ("chain", "draw", "group") + assert_allclose(extended_with_mu["mu"], dataset["mu_raw"].cumsum("group")) + + +def test_docstring_example(): + import pymc as pm + + with pm.Model(coords={"group": (0, 2, 4)}) as m: + mu_raw = pm.Normal("mu_raw", 0, 1, dims="group") + mu = pm.Deterministic("mu", mu_raw.cumsum(), dims="group") + + trace = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5) + + assert "mu" not in trace.posterior + + with m: + trace.posterior = pm.compute_deterministics(trace.posterior, merge_dataset=True) + + assert "mu" in trace.posterior diff --git a/tests/test_util.py b/tests/test_util.py index 61d916249e7..2a8c4164ca8 100644 --- a/tests/test_util.py +++ b/tests/test_util.py @@ -26,7 +26,6 @@ from pymc.util import ( UNSET, _get_seeds_per_chain, - dataset_to_point_list, drop_warning_stat, get_value_vars_from_user_vars, hash_key, @@ -156,38 +155,6 @@ def fn(a=UNSET): assert "a=UNSET" in captured.out -@pytest.mark.parametrize("input_type", ("dict", "Dataset")) -def test_dataset_to_point_list(input_type): - if input_type == "dict": - ds = {} - elif input_type == "Dataset": - ds = xarray.Dataset() - ds["A"] = xarray.DataArray([[1, 2, 3]] * 2, dims=("chain", "draw")) - pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"]) - assert isinstance(pl, list) - assert len(pl) == 6 - assert isinstance(pl[0], dict) - assert isinstance(pl[0]["A"], np.ndarray) - - -def test_transposed_dataset_to_point_list(): - 
ds = xarray.Dataset() - ds["A"] = xarray.DataArray([[[1, 2, 3], [2, 3, 4]]] * 5, dims=("team", "draw", "chain")) - pl, _ = dataset_to_point_list(ds, sample_dims=["chain", "draw"]) - assert isinstance(pl, list) - assert len(pl) == 6 - assert isinstance(pl[0], dict) - assert isinstance(pl[0]["A"], np.ndarray) - - -def test_dataset_to_point_list_str_key(): - # Check that non-str keys are caught - ds = xarray.Dataset() - ds[3] = xarray.DataArray([1, 2, 3]) - with pytest.raises(ValueError, match="must be str"): - dataset_to_point_list(ds, sample_dims=["chain", "draw"]) - - def test_drop_warning_stat(): idata = arviz.from_dict( sample_stats={
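
Usage note: the new public `pm.compute_deterministics` is a thin wrapper around the shared `apply_function_over_dataset` helper this diff adds to `pymc.backends.arviz` (the same engine now backing `compute_log_likelihood` and `compute_log_prior`). Below is a minimal sketch of composing that helper directly, reusing the model from the docstring example; the `fn` and `mu_ds` names are illustrative only and not part of the diff.

.. code:: python

    import pymc as pm

    from pymc.backends.arviz import (
        apply_function_over_dataset,
        coords_and_dims_for_inferencedata,
    )

    with pm.Model(coords={"group": (0, 2, 4)}) as model:
        mu_raw = pm.Normal("mu_raw", 0, 1, dims="group")
        mu = pm.Deterministic("mu", mu_raw.cumsum(), dims="group")
        idata = pm.sample(var_names=["mu_raw"], chains=2, tune=5, draws=5)

    # Compile a point function mapping free-RV values to the deterministic;
    # point_fn=True is the default, so compile_fn returns a PointFunc.
    fn = model.compile_fn(outs=[mu], inputs=model.free_RVs, on_unused_input="ignore")
    coords, dims = coords_and_dims_for_inferencedata(model)

    # Evaluate fn over every (chain, draw) point of the posterior subset,
    # getting back a Dataset with the original sample dims restored.
    mu_ds = apply_function_over_dataset(
        fn,
        idata.posterior[["mu_raw"]],
        output_var_names=["mu"],
        coords=coords,
        dims=dims,
    )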