diff --git a/chirho/observational/handlers/condition.py b/chirho/observational/handlers/condition.py index a097743c..ebdd43bc 100644 --- a/chirho/observational/handlers/condition.py +++ b/chirho/observational/handlers/condition.py @@ -1,10 +1,10 @@ -from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union +from typing import Callable, Generic, Mapping, TypeVar, Union import pyro import torch from chirho.observational.internals import ObserveNameMessenger -from chirho.observational.ops import AtomicObservation, observe +from chirho.observational.ops import Observation, observe T = TypeVar("T") R = Union[float, torch.Tensor] @@ -62,7 +62,9 @@ class Observations(Generic[T], ObserveNameMessenger): a richer set of observational data types and enables counterfactual inference. """ - def __init__(self, data: Mapping[Hashable, AtomicObservation[T]]): + data: Mapping[str, Observation[T]] + + def __init__(self, data: Mapping[str, Observation[T]]): self.data = data super().__init__() diff --git a/chirho/robust/internals/linearize.py b/chirho/robust/internals/linearize.py index 436ae033..d02ef120 100644 --- a/chirho/robust/internals/linearize.py +++ b/chirho/robust/internals/linearize.py @@ -5,7 +5,7 @@ import torch from typing_extensions import Concatenate, ParamSpec -from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood +from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood from chirho.robust.internals.utils import ( ParamDict, make_flatten_unflatten, @@ -92,23 +92,17 @@ def f_Ax_flat(v: torch.Tensor) -> torch.Tensor: def make_empirical_fisher_vp( - func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], + batched_func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor], log_prob_params: ParamDict, data: Point[T], *args: P.args, **kwargs: P.kwargs, ) -> Callable[[ParamDict], ParamDict]: - batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap( - lambda p, data: func_log_prob(p, data, *args, **kwargs), - in_dims=(None, 0), - randomness="different", - ) - N = data[next(iter(data))].shape[0] # type: ignore mean_vector = 1 / N * torch.ones(N) def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor: - return batched_func_log_prob(params, data) + return batched_func_log_prob(params, data, *args, **kwargs) def _empirical_fisher_vp(v: ParamDict) -> ParamDict: def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor: @@ -145,12 +139,11 @@ def linearize( num_samples=num_samples_outer, parallel=True, ) - predictive_params, func_predictive = make_functional_call(predictive) - log_prob = NMCLogPredictiveLikelihood( + batched_log_prob = BatchedNMCLogPredictiveLikelihood( model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting ) - log_prob_params, func_log_prob = make_functional_call(log_prob) + log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob) log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values()) if cg_iters is None: cg_iters = log_prob_params_numel @@ -166,23 +159,18 @@ def _fn( **kwargs: P.kwargs, ) -> ParamDict: with torch.no_grad(): - data: Point[T] = func_predictive(predictive_params, *args, **kwargs) + data: Point[T] = predictive(*args, **kwargs) data = {k: data[k] for k in points.keys()} fvp = make_empirical_fisher_vp( - func_log_prob, log_prob_params, data, *args, **kwargs + batched_func_log_prob, log_prob_params, data, *args, **kwargs ) pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp) pinned_fvp_batched = torch.func.vmap( lambda v: pinned_fvp(v), randomness="different" ) - batched_func_log_prob = torch.vmap( - lambda p, data: func_log_prob(p, data, *args, **kwargs), - in_dims=(None, 0), - randomness="different", - ) def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor: - return batched_func_log_prob(p, points) + return batched_func_log_prob(p, points, *args, **kwargs) if pointwise_influence: score_fn = torch.func.jacrev(bound_batched_func_log_prob) diff --git a/chirho/robust/internals/predictive.py b/chirho/robust/internals/predictive.py index 6e011de9..19369924 100644 --- a/chirho/robust/internals/predictive.py +++ b/chirho/robust/internals/predictive.py @@ -1,15 +1,21 @@ -import contextlib +import collections import math -import warnings -from typing import Any, Callable, Container, Generic, Optional, TypeVar +import typing +from typing import Any, Callable, Generic, Optional, TypeVar import pyro import torch from typing_extensions import ParamSpec -from chirho.indexed.handlers import DependentMaskMessenger -from chirho.observational.handlers import condition -from chirho.robust.internals.utils import guess_max_plate_nesting +from chirho.indexed.handlers import IndexPlatesMessenger +from chirho.indexed.ops import get_index_plates, indices_of +from chirho.observational.handlers.condition import Observations +from chirho.robust.internals.utils import ( + bind_leftmost_dim, + get_importance_traces, + site_is_delta, + unbind_leftmost_dim, +) from chirho.robust.ops import Point pyro.settings.set(module_local_params=True) @@ -20,27 +26,151 @@ T = TypeVar("T") -class _UnmaskNamedSites(DependentMaskMessenger): - names: Container[str] +class BatchedLatents(pyro.poutine.messenger.Messenger): + """ + Effect handler that adds a fresh batch dimension to all latent ``sample`` sites. + Similar to wrapping a Pyro model in a ``pyro.plate`` context, but uses the machinery + in ``chirho.indexed`` to automatically allocate and track the fresh batch dimension + based on the ``name`` argument to ``BatchedLatents`` . - def __init__(self, names: Container[str]): - self.names = names + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . - def get_mask( + :param int num_particles: Number of particles to use for parallelization. + :param str name: Name of the fresh batch dimension. + """ + + num_particles: int + name: str + + def __init__(self, num_particles: int, *, name: str = "__particles_mc"): + assert num_particles > 0 + assert len(name) > 0 + self.num_particles = num_particles + self.name = name + super().__init__() + + def _pyro_sample(self, msg: dict) -> None: + if ( + self.num_particles > 1 + and msg["value"] is None + and not pyro.poutine.util.site_is_factor(msg) + and not pyro.poutine.util.site_is_subsample(msg) + and not site_is_delta(msg) + and self.name not in indices_of(msg["fn"]) + ): + msg["fn"] = unbind_leftmost_dim( + msg["fn"].expand((1,) + msg["fn"].batch_shape), + self.name, + size=self.num_particles, + ) + + +class BatchedObservations(Generic[T], Observations[T]): + """ + Effect handler that takes a dictionary of observation values for ``sample`` sites + that are assumed to be batched along their leftmost dimension, adds a fresh named + dimension using the machinery in ``chirho.indexed``, and reshapes the observation + values so that the new ``chirho.observational.observe`` sites are batched along + the fresh named dimension. + + Useful in combination with ``pyro.infer.Predictive`` which returns a dictionary + of values whose leftmost dimension is a batch dimension over independent samples. + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param Point[T] data: Dictionary of observation values. + :param str name: Name of the fresh batch dimension. + """ + + name: str + + def __init__(self, data: Point[T], *, name: str = "__particles_data"): + assert len(name) > 0 + self.name = name + super().__init__(data) + + def _pyro_observe(self, msg: dict) -> None: + super()._pyro_observe(msg) + if msg["kwargs"]["name"] in self.data: + rv, obs = msg["args"] + event_dim = ( + len(rv.event_shape) + if hasattr(rv, "event_shape") + else msg["kwargs"].get("event_dim", 0) + ) + batch_obs = unbind_leftmost_dim(obs, self.name, event_dim=event_dim) + msg["args"] = (rv, batch_obs) + + +class PredictiveModel(Generic[P, T], torch.nn.Module): + """ + Given a Pyro model and guide, constructs a new model that behaves as if + the latent ``sample`` sites in the original model (i.e. the prior) + were replaced by their counterparts in the guide (i.e. the posterior). + + .. note:: Sites that only appear in the model are annotated in traces + produced by the predictive model with ``infer={"_model_predictive_site": True}`` . + + :param model: Pyro model. + :param guide: Pyro guide. + """ + + model: Callable[P, T] + guide: Callable[P, Any] + + def __init__( self, - dist: pyro.distributions.Distribution, - value: Optional[torch.Tensor], - device: torch.device = torch.device("cpu"), - name: Optional[str] = None, - ) -> torch.Tensor: - return torch.tensor(name is None or name in self.names, device=device) + model: Callable[P, T], + guide: Callable[P, Any], + ): + super().__init__() + self.model = model + self.guide = guide + + def forward(self, *args: P.args, **kwargs: P.kwargs) -> T: + with pyro.poutine.trace() as guide_tr: + self.guide(*args, **kwargs) + + block_guide_sample_sites = pyro.poutine.block( + hide=[ + name + for name, node in guide_tr.trace.nodes.items() + if node["type"] == "sample" + ] + ) + + with pyro.poutine.infer_config( + config_fn=lambda msg: {"_model_predictive_site": True} + ): + with block_guide_sample_sites: + with pyro.poutine.replay(trace=guide_tr.trace): + return self.model(*args, **kwargs) class PredictiveFunctional(Generic[P, T], torch.nn.Module): + """ + Functional that returns a batch of samples from the posterior predictive + distribution of a Pyro model given a guide. As with ``pyro.infer.Predictive`` , + the returned values are batched along their leftmost positional dimension. + + Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)`` + but uses :class:`~PredictiveModel` to construct the predictive distribution + and infer the model ``sample`` sites whose values should be returned, + and uses :class:`~BatchedLatents` to parallelize over samples from the guide. + + .. warning:: ``PredictiveFunctional`` currently applies its own internal instance of + :class:`~chirho.indexed.handlers.IndexPlatesMessenger` , + so it may not behave as expected if used within another enclosing + :class:`~chirho.indexed.handlers.IndexPlatesMessenger` context. + + :param model: Pyro model. + :param guide: Pyro guide. + :param num_samples: Number of samples to return. + """ + model: Callable[P, Any] guide: Callable[P, Any] num_samples: int - max_plate_nesting: Optional[int] def __init__( self, @@ -49,59 +179,41 @@ def __init__( *, num_samples: int = 1, max_plate_nesting: Optional[int] = None, + name: str = "__particles_predictive", ): super().__init__() self.model = model self.guide = guide self.num_samples = num_samples - self.max_plate_nesting = max_plate_nesting - - def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: - if self.max_plate_nesting is None: - self.max_plate_nesting = guess_max_plate_nesting( - self.model, self.guide, *args, **kwargs - ) - - particles_plate = ( - contextlib.nullcontext() - if self.num_samples == 1 - else pyro.plate( - "__predictive_particles", - self.num_samples, - dim=-self.max_plate_nesting - 1, - ) + self._predictive_model: PredictiveModel[P, Any] = PredictiveModel(model, guide) + self._first_available_dim = ( + -max_plate_nesting - 1 if max_plate_nesting is not None else None ) + self._mc_plate_name = name - with pyro.poutine.trace() as guide_tr, particles_plate: - self.guide(*args, **kwargs) + def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]: + with IndexPlatesMessenger(first_available_dim=self._first_available_dim): + with pyro.poutine.trace() as model_tr: + with BatchedLatents(self.num_samples, name=self._mc_plate_name): + self._predictive_model(*args, **kwargs) - block_guide_sample_sites = pyro.poutine.block( - hide=[ - name - for name, node in guide_tr.trace.nodes.items() + return { + name: bind_leftmost_dim( + node["value"], + self._mc_plate_name, + event_dim=len(node["fn"].event_shape), + ) + for name, node in model_tr.trace.nodes.items() if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample(node) - ] - ) + and node["infer"].get("_model_predictive_site", False) + } - with pyro.poutine.trace() as model_tr: - with block_guide_sample_sites: - with pyro.poutine.replay(trace=guide_tr.trace), particles_plate: - self.model(*args, **kwargs) - return { - name: node["value"] - for name, node in model_tr.trace.nodes.items() - if node["type"] == "sample" - and not pyro.poutine.util.site_is_subsample(node) - } - - -class NMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): +class BatchedNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): model: Callable[P, Any] guide: Callable[P, Any] num_samples: int - max_plate_nesting: Optional[int] def __init__( self, @@ -110,36 +222,71 @@ def __init__( *, num_samples: int = 1, max_plate_nesting: Optional[int] = None, + data_plate_name: str = "__particles_data", + mc_plate_name: str = "__particles_mc", ): super().__init__() self.model = model self.guide = guide self.num_samples = num_samples - self.max_plate_nesting = max_plate_nesting + self._first_available_dim = ( + -max_plate_nesting - 1 if max_plate_nesting is not None else None + ) + self._data_plate_name = data_plate_name + self._mc_plate_name = mc_plate_name def forward( self, data: Point[T], *args: P.args, **kwargs: P.kwargs ) -> torch.Tensor: - if self.max_plate_nesting is None: - self.max_plate_nesting = guess_max_plate_nesting( - self.model, self.guide, *args, **kwargs - ) - warnings.warn( - "Since max_plate_nesting is not specified, \ - the first call to NMCLogPredictiveLikelihood will not be seeded properly. \ - See https://github.com/BasisResearch/chirho/pull/408" - ) + get_nmc_traces = get_importance_traces(PredictiveModel(self.model, self.guide)) + + with IndexPlatesMessenger(first_available_dim=self._first_available_dim): + with BatchedLatents(self.num_samples, name=self._mc_plate_name): + with BatchedObservations(data, name=self._data_plate_name): + model_trace, guide_trace = get_nmc_traces(*args, **kwargs) + index_plates = get_index_plates() - masked_guide = pyro.poutine.mask(mask=False)(self.guide) - masked_model = _UnmaskNamedSites(names=set(data.keys()))( - condition(data=data)(self.model) + plate_name_to_dim = collections.OrderedDict( + (p, index_plates[p]) + for p in [self._mc_plate_name, self._data_plate_name] + if p in index_plates ) - log_weights = pyro.infer.importance.vectorized_importance_weights( - masked_model, - masked_guide, - *args, - num_samples=self.num_samples, - max_plate_nesting=self.max_plate_nesting, - **kwargs, - )[0] - return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) + plate_frames = set(plate_name_to_dim.values()) + + log_weights = typing.cast(torch.Tensor, 0.0) + for site in model_trace.nodes.values(): + if site["type"] != "sample": + continue + site_log_prob = site["log_prob"] + for f in site["cond_indep_stack"]: + if f.dim is not None and f not in plate_frames: + site_log_prob = site_log_prob.sum(f.dim, keepdim=True) + log_weights = log_weights + site_log_prob + + for site in guide_trace.nodes.values(): + if site["type"] != "sample": + continue + site_log_prob = site["log_prob"] + for f in site["cond_indep_stack"]: + if f.dim is not None and f not in plate_frames: + site_log_prob = site_log_prob.sum(f.dim, keepdim=True) + log_weights = log_weights - site_log_prob + + # sum out particle dimension and discard + if self._mc_plate_name in index_plates: + log_weights = torch.logsumexp( + log_weights, + dim=plate_name_to_dim[self._mc_plate_name].dim, + keepdim=True, + ) - math.log(self.num_samples) + plate_name_to_dim.pop(self._mc_plate_name) + + # move data plate dimension to the left + for name in reversed(plate_name_to_dim.keys()): + log_weights = bind_leftmost_dim(log_weights, name) + + # pack log_weights by squeezing out rightmost dimensions + for _ in range(len(log_weights.shape) - len(plate_name_to_dim)): + log_weights = log_weights.squeeze(-1) + + return log_weights diff --git a/chirho/robust/internals/utils.py b/chirho/robust/internals/utils.py index 7af0af4e..12094a1e 100644 --- a/chirho/robust/internals/utils.py +++ b/chirho/robust/internals/utils.py @@ -1,11 +1,15 @@ import contextlib import functools -from typing import Any, Callable, Dict, Mapping, Tuple, TypeVar +import math +from typing import Any, Callable, Dict, Mapping, Optional, Tuple, TypeVar import pyro import torch from typing_extensions import Concatenate, ParamSpec +from chirho.indexed.handlers import add_indices +from chirho.indexed.ops import IndexSet, get_index_plates, indices_of + P = ParamSpec("P") Q = ParamSpec("Q") S = TypeVar("S") @@ -106,3 +110,126 @@ def reset_rng_state(rng_state: T): yield pyro.util.set_rng_state(rng_state) finally: pyro.util.set_rng_state(prev_rng_state) + + +@functools.singledispatch +def unbind_leftmost_dim(v, name: str, size: int = 1, **kwargs): + """ + Helper function to move the leftmost dimension of a ``torch.Tensor`` + or ``pyro.distributions.Distribution`` or other batched value + into a fresh named dimension using the machinery in ``chirho.indexed`` , + allocating a new dimension with the given name if necessary + via an enclosing :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + + :param v: Batched value. + :param name: Name of the fresh dimension. + :param size: Size of the fresh dimension. If 1, the size is inferred from ``v`` . + """ + raise NotImplementedError + + +@unbind_leftmost_dim.register +def _unbind_leftmost_dim_tensor( + v: torch.Tensor, name: str, size: int = 1, *, event_dim: int = 0 +) -> torch.Tensor: + size = max(size, v.shape[0]) + v = v.expand((size,) + v.shape[1:]) + + if name not in get_index_plates(): + add_indices(IndexSet(**{name: set(range(size))})) + + new_dim: int = get_index_plates()[name].dim + orig_shape = v.shape + while new_dim - event_dim < -len(v.shape): + v = v[None] + if v.shape[0] == 1 and orig_shape[0] != 1: + v = torch.transpose(v, -len(orig_shape), new_dim - event_dim) + return v + + +@unbind_leftmost_dim.register +def _unbind_leftmost_dim_distribution( + v: pyro.distributions.Distribution, name: str, size: int = 1, **kwargs +) -> pyro.distributions.Distribution: + size = max(size, v.batch_shape[0]) + if v.batch_shape[0] != 1: + raise NotImplementedError("Cannot freely reshape distribution") + + if name not in get_index_plates(): + add_indices(IndexSet(**{name: set(range(size))})) + + new_dim: int = get_index_plates()[name].dim + orig_shape = v.batch_shape + + new_shape = (size,) + (1,) * (-new_dim - len(orig_shape)) + orig_shape[1:] + return v.expand(new_shape) + + +@functools.singledispatch +def bind_leftmost_dim(v, name: str, **kwargs): + """ + Helper function to move a named dimension managed by ``chirho.indexed`` + into a new unnamed dimension to the left of all named dimensions in the value. + + .. warning:: Must be used in conjunction with :class:`~chirho.indexed.handlers.IndexPlatesMessenger` . + """ + raise NotImplementedError + + +@bind_leftmost_dim.register +def _bind_leftmost_dim_tensor( + v: torch.Tensor, name: str, *, event_dim: int = 0, **kwargs +) -> torch.Tensor: + if name not in indices_of(v, event_dim=event_dim): + return v + return torch.transpose( + v[None], -len(v.shape) - 1, get_index_plates()[name].dim - event_dim + ) + + +def get_importance_traces( + model: Callable[P, Any], + guide: Optional[Callable[P, Any]] = None, +) -> Callable[P, Tuple[pyro.poutine.Trace, pyro.poutine.Trace]]: + """ + Thin functional wrapper around :func:`~pyro.infer.enum.get_importance_trace` + that cleans up the original interface to avoid unnecessary arguments + and efficiently supports using the prior in a model as a default guide. + + :param model: Model to run. + :param guide: Guide to run. If ``None``, use the prior in ``model`` as a guide. + :returns: A function that takes the same arguments as ``model`` and ``guide`` and returns + a tuple of importance traces ``(model_trace, guide_trace)``. + """ + + def _fn( + *args: P.args, **kwargs: P.kwargs + ) -> Tuple[pyro.poutine.Trace, pyro.poutine.Trace]: + if guide is not None: + model_trace, guide_trace = pyro.infer.enum.get_importance_trace( + "flat", math.inf, model, guide, args, kwargs + ) + return model_trace, guide_trace + else: # use prior as default guide, but don't run model twice + model_trace, _ = pyro.infer.enum.get_importance_trace( + "flat", math.inf, model, lambda *_, **__: None, args, kwargs + ) + + guide_trace = model_trace.copy() + for name, node in list(guide_trace.nodes.items()): + if node["type"] != "sample": + del model_trace.nodes[name] + elif pyro.poutine.util.site_is_factor(node) or node["is_observed"]: + del guide_trace.nodes[name] + return model_trace, guide_trace + + return _fn + + +def site_is_delta(msg: dict) -> bool: + d = msg["fn"] + while hasattr(d, "base_dist"): + d = d.base_dist + return isinstance(d, pyro.distributions.Delta) diff --git a/tests/robust/test_internals_compositions.py b/tests/robust/test_internals_compositions.py index 1e7aea90..b9924fab 100644 --- a/tests/robust/test_internals_compositions.py +++ b/tests/robust/test_internals_compositions.py @@ -5,11 +5,17 @@ import pytest import torch +from chirho.indexed.handlers import IndexPlatesMessenger +from chirho.indexed.ops import indices_of from chirho.robust.internals.linearize import ( conjugate_gradient_solve, make_empirical_fisher_vp, ) -from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood +from chirho.robust.internals.predictive import ( + BatchedLatents, + BatchedNMCLogPredictiveLikelihood, + BatchedObservations, +) from chirho.robust.internals.utils import make_functional_call, reset_rng_state from .robust_fixtures import SimpleGuide, SimpleModel @@ -21,7 +27,7 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): model = SimpleModel() guide = SimpleGuide() model(), guide() # initialize - log_prob = NMCLogPredictiveLikelihood(model, guide, num_samples=100) + log_prob = BatchedNMCLogPredictiveLikelihood(model, guide, num_samples=100) log_prob_params, func_log_prob = make_functional_call(log_prob) func_log_prob = reset_rng_state(pyro.util.get_rng_state())(func_log_prob) @@ -49,6 +55,7 @@ def test_empirical_fisher_vp_nmclikelihood_cg_composition(): # For this model, fvp for loc_a is zero. See # https://github.com/BasisResearch/chirho/issues/427 assert fvp(v)["guide.loc_a"].abs().max() == 0 + assert all(fvp_vk.shape == v[k].shape for k, fvp_vk in fvp(v).items()) solve_one = cg_solver(fvp, v) solve_two = cg_solver(fvp, v) @@ -88,7 +95,7 @@ def test_nmc_likelihood_seeded(link_fn): guide = SimpleGuide() model(), guide() # initialize - log_prob = NMCLogPredictiveLikelihood( + log_prob = BatchedNMCLogPredictiveLikelihood( model, guide, num_samples=3, max_plate_nesting=3 ) log_prob_params, func_log_prob = make_functional_call(log_prob) @@ -113,3 +120,96 @@ def test_nmc_likelihood_seeded(link_fn): # Check if fvp agrees across multiple calls of same `fvp` object assert torch.allclose(fvp(v)["guide.loc_a"], fvp(v)["guide.loc_a"]) assert torch.allclose(fvp(v)["guide.loc_b"], fvp(v)["guide.loc_b"]) + + +@pytest.mark.parametrize("pad_dim", [0, 1, 2]) +def test_batched_observations(pad_dim: int): + max_plate_nesting = 1 + pad_dim + obs_plate_name = "__dummy_plate__" + num_particles_obs = 3 + model = SimpleModel() + guide = SimpleGuide() + + model(), guide() # initialize + + predictive = pyro.infer.Predictive( + model, num_samples=num_particles_obs, return_sites=["y"] + ) + + test_data = predictive() + + with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1): + with pyro.poutine.trace() as tr: + with BatchedObservations(test_data, name=obs_plate_name): + model() + + tr.trace.compute_log_prob() + + for name, node in tr.trace.nodes.items(): + if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample( + node + ): + if name in test_data: + assert obs_plate_name in indices_of(node["log_prob"], event_dim=0) + assert obs_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + else: + assert obs_plate_name not in indices_of( + node["log_prob"], event_dim=0 + ) + assert obs_plate_name not in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + + +@pytest.mark.parametrize("pad_dim", [0, 1, 2]) +def test_batched_latents_observations(pad_dim: int): + max_plate_nesting = 1 + pad_dim + num_particles_latent = 5 + num_particles_obs = 3 + obs_plate_name = "__dummy_plate__" + latent_plate_name = "__dummy_latents__" + model = SimpleModel() + guide = SimpleGuide() + + model(), guide() # initialize + + predictive = pyro.infer.Predictive( + model, num_samples=num_particles_obs, return_sites=["y"] + ) + + test_data = predictive() + + with IndexPlatesMessenger(first_available_dim=-max_plate_nesting - 1): + with pyro.poutine.trace() as tr: + with BatchedLatents( + num_particles=num_particles_latent, name=latent_plate_name + ): + with BatchedObservations(test_data, name=obs_plate_name): + model() + + tr.trace.compute_log_prob() + + for name, node in tr.trace.nodes.items(): + if node["type"] == "sample" and not pyro.poutine.util.site_is_subsample( + node + ): + if name in test_data: + assert obs_plate_name in indices_of(node["log_prob"], event_dim=0) + assert obs_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + assert latent_plate_name in indices_of( + node["log_prob"], event_dim=0 + ) + assert latent_plate_name not in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) + else: + assert latent_plate_name in indices_of( + node["log_prob"], event_dim=0 + ) + assert latent_plate_name in indices_of( + node["value"], event_dim=len(node["fn"].event_shape) + ) diff --git a/tests/robust/test_performance.py b/tests/robust/test_performance.py new file mode 100644 index 00000000..b1ec08f2 --- /dev/null +++ b/tests/robust/test_performance.py @@ -0,0 +1,178 @@ +import math +import time +import warnings +from functools import partial +from typing import Any, Callable, Container, Generic, Optional, TypeVar + +import pyro +import pytest +import torch +from typing_extensions import ParamSpec + +from chirho.indexed.handlers import DependentMaskMessenger +from chirho.observational.handlers import condition +from chirho.robust.internals.linearize import make_empirical_fisher_vp +from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood +from chirho.robust.internals.utils import guess_max_plate_nesting, make_functional_call +from chirho.robust.ops import Point + +from .robust_fixtures import SimpleGuide, SimpleModel + +pyro.settings.set(module_local_params=True) + +P = ParamSpec("P") +Q = ParamSpec("Q") +S = TypeVar("S") +T = TypeVar("T") + + +class _UnmaskNamedSites(DependentMaskMessenger): + names: Container[str] + + def __init__(self, names: Container[str]): + self.names = names + + def get_mask( + self, + dist: pyro.distributions.Distribution, + value: Optional[torch.Tensor], + device: torch.device = torch.device("cpu"), + name: Optional[str] = None, + ) -> torch.Tensor: + return torch.tensor(name is None or name in self.names, device=device) + + +class OldNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module): + model: Callable[P, Any] + guide: Callable[P, Any] + num_samples: int + max_plate_nesting: Optional[int] + + def __init__( + self, + model: torch.nn.Module, + guide: torch.nn.Module, + *, + num_samples: int = 1, + max_plate_nesting: Optional[int] = None, + ): + super().__init__() + self.model = model + self.guide = guide + self.num_samples = num_samples + self.max_plate_nesting = max_plate_nesting + + def forward( + self, data: Point[T], *args: P.args, **kwargs: P.kwargs + ) -> torch.Tensor: + if self.max_plate_nesting is None: + self.max_plate_nesting = guess_max_plate_nesting( + self.model, self.guide, *args, **kwargs + ) + warnings.warn( + "Since max_plate_nesting is not specified, \ + the first call to NMCLogPredictiveLikelihood will not be seeded properly. \ + See https://github.com/BasisResearch/chirho/pull/408" + ) + + masked_guide = pyro.poutine.mask(mask=False)(self.guide) + masked_model = _UnmaskNamedSites(names=set(data.keys()))( + condition(data=data)(self.model) + ) + log_weights = pyro.infer.importance.vectorized_importance_weights( + masked_model, + masked_guide, + *args, + num_samples=self.num_samples, + max_plate_nesting=self.max_plate_nesting, + **kwargs, + )[0] + return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples) + + +class SimpleMultivariateGaussianModel(pyro.nn.PyroModule): + def __init__(self, p): + super().__init__() + self.p = p + + def forward(self): + loc = pyro.sample( + "loc", pyro.distributions.Normal(torch.zeros(self.p), 1.0).to_event(1) + ) + cov_mat = torch.eye(self.p) + return pyro.sample("y", pyro.distributions.MultivariateNormal(loc, cov_mat)) + + +class SimpleMultivariateGuide(torch.nn.Module): + def __init__(self, p): + super().__init__() + self.loc_ = torch.nn.Parameter(torch.rand((p,))) + self.p = p + + def forward(self): + return pyro.sample("loc", pyro.distributions.Normal(self.loc_, 1).to_event(1)) + + +model_guide_types = [ + ( + partial(SimpleMultivariateGaussianModel, p=500), + partial(SimpleMultivariateGuide, p=500), + ), + (SimpleModel, SimpleGuide), +] + + +@pytest.mark.skip(reason="This test is too slow to run on CI") +@pytest.mark.parametrize("model_guide", model_guide_types) +def test_empirical_fisher_vp_performance_with_likelihood(model_guide): + num_monte_carlo = 10000 + model_family, guide_family = model_guide + + model = model_family() + guide = guide_family() + + model() + guide() + + start_time = time.time() + data = pyro.infer.Predictive( + model, guide=guide, num_samples=num_monte_carlo, return_sites=["y"] + )() + end_time = time.time() + print("Data generation time (s): ", end_time - start_time) + + log1_prob_params, func1_log_prob = make_functional_call( + OldNMCLogPredictiveLikelihood(model, guide, max_plate_nesting=1) + ) + batched_func1_log_prob = torch.func.vmap( + func1_log_prob, in_dims=(None, 0), randomness="different" + ) + + log2_prob_params, func2_log_prob = make_functional_call( + BatchedNMCLogPredictiveLikelihood(model, guide) + ) + + fisher_hessian_vmapped = make_empirical_fisher_vp( + batched_func1_log_prob, log1_prob_params, data + ) + + fisher_hessian_batched = make_empirical_fisher_vp( + func2_log_prob, log2_prob_params, data + ) + + v = { + k: torch.ones_like(v) if k != "guide.loc_a" else torch.zeros_like(v) + for k, v in log1_prob_params.items() + } + + func2_log_prob(log2_prob_params, data) + + start_time = time.time() + fisher_hessian_vmapped(v) + end_time = time.time() + print("Hessian vmapped time (s): ", end_time - start_time) + + start_time = time.time() + fisher_hessian_batched(v) + end_time = time.time() + print("Hessian manual batched time (s): ", end_time - start_time)