Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove guide argument from influence_fn and linearize #489

Merged
merged 12 commits into from
Jan 9, 2024
18 changes: 9 additions & 9 deletions chirho/robust/handlers/estimators.py
Original file line number Diff line number Diff line change
@@ -1,13 +1,16 @@
from typing import Any, Callable
from typing import Any, Callable, TypeVar

from typing_extensions import Concatenate
from typing_extensions import Concatenate, ParamSpec

from chirho.robust.ops import Functional, P, Point, S, T, influence_fn
from chirho.robust.ops import Functional, Point, influence_fn

P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")


def one_step_correction(
model: Callable[P, Any],
guide: Callable[P, Any],
functional: Functional[P, S],
**influence_kwargs,
) -> Callable[Concatenate[Point[T], P], S]:
Expand All @@ -18,10 +21,7 @@ def one_step_correction(

:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param guide: Python callable containing Pyro primitives.
:type guide: Callable[P, Any]
:param functional: model summary of interest, which is a function of the
model and guide.
:param functional: model summary of interest, which is a function of the model.
:type functional: Functional[P, S]
:return: function to compute the one-step correction
:rtype: Callable[Concatenate[Point[T], P], S]
Expand All @@ -33,7 +33,7 @@ def one_step_correction(
"""
influence_kwargs_one_step = influence_kwargs.copy()
influence_kwargs_one_step["pointwise_influence"] = False
eif_fn = influence_fn(model, guide, functional, **influence_kwargs_one_step)
eif_fn = influence_fn(model, functional, **influence_kwargs_one_step)

def _one_step(test_data: Point[T], *args, **kwargs) -> S:
return eif_fn(test_data, *args, **kwargs)
Expand Down
140 changes: 140 additions & 0 deletions chirho/robust/handlers/predictive.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,140 @@
from typing import Any, Callable, Generic, Optional, TypeVar

import pyro
import torch
from typing_extensions import ParamSpec

from chirho.indexed.handlers import IndexPlatesMessenger
from chirho.robust.internals.nmc import BatchedLatents
from chirho.robust.internals.utils import bind_leftmost_dim
from chirho.robust.ops import Point

P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")


class PredictiveModel(Generic[P, T], torch.nn.Module):
"""
Given a Pyro model and guide, constructs a new model that behaves as if
the latent ``sample`` sites in the original model (i.e. the prior)
were replaced by their counterparts in the guide (i.e. the posterior).

.. note:: Sites that only appear in the model are annotated in traces
produced by the predictive model with ``infer={"_model_predictive_site": True}`` .

:param model: Pyro model.
:param guide: Pyro guide.
"""

model: Callable[P, T]
guide: Optional[Callable[P, Any]]

def __init__(
self,
model: Callable[P, T],
guide: Optional[Callable[P, Any]] = None,
):
super().__init__()
self.model = model
self.guide = guide

def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""
Returns a sample from the posterior predictive distribution.

:return: Sample from the posterior predictive distribution.
:rtype: T
"""
with pyro.poutine.infer_config(
config_fn=lambda msg: {"_model_predictive_site": False}
):
with pyro.poutine.trace() as guide_tr:
if self.guide is not None:
self.guide(*args, **kwargs)

block_guide_sample_sites = pyro.poutine.block(
hide=[
name
for name, node in guide_tr.trace.nodes.items()
if node["type"] == "sample"
]
)

with pyro.poutine.infer_config(
config_fn=lambda msg: {"_model_predictive_site": True}
):
with block_guide_sample_sites:
with pyro.poutine.replay(trace=guide_tr.trace):
return self.model(*args, **kwargs)


class PredictiveFunctional(Generic[P, T], torch.nn.Module):
"""
Functional that returns a batch of samples from the predictive
distribution of a Pyro model. As with ``pyro.infer.Predictive`` ,
the returned values are batched along their leftmost positional dimension.

Similar to ``pyro.infer.Predictive(model, guide, num_samples, parallel=True)``
when :class:`~chirho.robust.handlers.predictive.PredictiveModel` is used to construct
the ``model`` argument and infer the ``sample`` sites whose values should be returned,
and uses :class:`~BatchedLatents` to parallelize over samples from the model.

.. warning:: ``PredictiveFunctional`` currently applies its own internal instance of
:class:`~chirho.indexed.handlers.IndexPlatesMessenger` ,
so it may not behave as expected if used within another enclosing
:class:`~chirho.indexed.handlers.IndexPlatesMessenger` context.

:param model: Pyro model.
:param num_samples: Number of samples to return.
"""

model: Callable[P, Any]
num_samples: int

def __init__(
self,
model: torch.nn.Module,
*,
num_samples: int = 1,
max_plate_nesting: Optional[int] = None,
name: str = "__particles_predictive",
):
super().__init__()
self.model = model
self.num_samples = num_samples
self._first_available_dim = (
-max_plate_nesting - 1 if max_plate_nesting is not None else None
)
self._mc_plate_name = name

def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]:
"""
Returns a batch of samples from the posterior predictive distribution.

:return: Dictionary of samples from the posterior predictive distribution.
:rtype: Point[T]
"""
with IndexPlatesMessenger(first_available_dim=self._first_available_dim):
with pyro.poutine.trace() as model_tr:
with BatchedLatents(self.num_samples, name=self._mc_plate_name):
with pyro.poutine.infer_config(
config_fn=lambda msg: {
"_model_predictive_site": msg["infer"].get(
"_model_predictive_site", True
)
}
):
self.model(*args, **kwargs)

return {
name: bind_leftmost_dim(
node["value"],
self._mc_plate_name,
event_dim=len(node["fn"].event_shape),
)
for name, node in model_tr.trace.nodes.items()
if node["type"] == "sample"
and not pyro.poutine.util.site_is_subsample(node)
and node["infer"].get("_model_predictive_site", False)
}
27 changes: 11 additions & 16 deletions chirho/robust/internals/linearize.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@
import torch
from typing_extensions import Concatenate, ParamSpec

from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood
from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood
from chirho.robust.internals.utils import (
ParamDict,
make_flatten_unflatten,
Expand Down Expand Up @@ -220,7 +220,6 @@ def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:

def linearize(
model: Callable[P, Any],
guide: Callable[P, Any],
*,
num_samples_outer: int,
num_samples_inner: Optional[int] = None,
Expand All @@ -231,26 +230,23 @@ def linearize(
) -> Callable[Concatenate[Point[T], P], ParamDict]:
r"""
Returns the influence function associated with the parameters
of ``guide`` and probabilistic program ``model``. This function
of a normalized probabilistic program ``model``. This function
computes the following quantity at an arbitrary point :math:`x^{\prime}`:

.. math::

\left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) \right]
\nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}), \quad
\tilde{p}_{\phi}(x) = \int p(x \mid \theta) q_{\phi}(\theta) d\theta,
\tilde{p}_{\phi}(x) = \int p_{\phi}(x, \theta) d\theta,

where :math:`\phi` corresponds to ``log_prob_params``,
:math:`p(x \mid \theta)` denotes the ``model``, :math:`q_{\phi}` denotes the ``guide``,
:math:`p(x, \theta)` denotes the ``model``,
:math:`\tilde{p}_{\phi}` denotes the predictive distribution ``log_prob`` induced
from the ``model`` and ``guide``, and :math:`\{x_n\}_{n=1}^N` are the
from the ``model``, and :math:`\{x_n\}_{n=1}^N` are the
data points drawn iid from the predictive distribution.

:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param guide: Python callable containing Pyro primitives.
Must only contain continuous latent variables.
:type guide: Callable[P, Any]
:param num_samples_outer: number of Monte Carlo samples to
approximate Fisher information in :func:`make_empirical_fisher_vp`
:type num_samples_outer: int
Expand All @@ -276,10 +272,12 @@ def linearize(
**Example usage**:

.. code-block:: python

import pyro
import pyro.distributions as dist
import torch

from chirho.robust.handlers.predictive import PredictiveModel
from chirho.robust.internals.linearize import linearize

pyro.settings.set(module_local_params=True)
Expand Down Expand Up @@ -312,8 +310,7 @@ def forward(self):
)
points = predictive()
influence = linearize(
model,
guide,
PredictiveModel(model, guide),
num_samples_outer=1000,
num_samples_inner=1000,
)
Expand All @@ -327,24 +324,22 @@ def forward(self):
can result in different values. To reduce variance, increase ``num_samples_outer`` and
``num_samples_inner`` in ``linearize_kwargs``.

* Currently, ``model`` and ``guide`` cannot contain any ``pyro.param`` statements.
* Currently, ``model`` cannot contain any ``pyro.param`` statements.
This issue will be addressed in a future release:
https://github.com/BasisResearch/chirho/issues/393.
"""
assert isinstance(model, torch.nn.Module)
assert isinstance(guide, torch.nn.Module)
if num_samples_inner is None:
num_samples_inner = num_samples_outer**2

predictive = pyro.infer.Predictive(
model,
guide=guide,
num_samples=num_samples_outer,
parallel=True,
)

batched_log_prob = BatchedNMCLogPredictiveLikelihood(
model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
batched_log_prob = BatchedNMCLogMarginalLikelihood(
model, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
)
log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob)
log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
Expand Down
Loading
Loading