Added documentation for chirho.robust (#470)
* documentation

* documentation clean up w/ eli

* fix lint issue
agrawalraj authored Jan 2, 2024
1 parent 878eb0d commit 3cfe319
Showing 6 changed files with 420 additions and 12 deletions.
20 changes: 20 additions & 0 deletions chirho/robust/handlers/estimators.py
@@ -11,6 +11,26 @@ def one_step_correction(
functional: Optional[Functional[P, S]] = None,
**influence_kwargs,
) -> Callable[Concatenate[Point[T], P], S]:
"""
Returns a function that computes the one-step correction for the
functional at a specified set of test points as discussed in
[1].
:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param guide: Python callable containing Pyro primitives.
:type guide: Callable[P, Any]
:param functional: model summary of interest, which is a function of the
model and guide. If ``None``, defaults to :class:`PredictiveFunctional`.
:type functional: Optional[Functional[P, S]], optional
:return: function to compute the one-step correction
:rtype: Callable[Concatenate[Point[T], P], S]
**References**
[1] `Semiparametric doubly robust targeted double machine learning: a review`,
Edward H. Kennedy, 2022.
"""
influence_kwargs_one_step = influence_kwargs.copy()
influence_kwargs_one_step["pointwise_influence"] = False
eif_fn = influence_fn(model, guide, functional, **influence_kwargs_one_step)
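
A minimal usage sketch (not part of this diff) follows, assuming the ``SimpleModel``, ``SimpleGuide``, and ``points`` defined in the ``linearize`` example later in this commit, and assuming the extra keyword arguments are forwarded through ``influence_fn`` down to ``linearize``. Since ``pointwise_influence`` is set to ``False`` above, the returned value is already averaged over ``points``; the one-step estimator adds it to the plug-in estimate of the functional (Kennedy, 2022).

.. code-block:: python

from chirho.robust.handlers.estimators import one_step_correction

model, guide = SimpleModel(), SimpleGuide()

# one_step_correction returns a callable; evaluating it at a batch of test
# points yields the one-step correction term for the chosen functional.
correction_fn = one_step_correction(
    model,
    guide,
    functional=None,          # defaults to PredictiveFunctional per the docstring
    num_samples_outer=1000,   # assumed to be forwarded to linearize
    num_samples_inner=1000,
)
correction = correction_fn(points)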
227 changes: 215 additions & 12 deletions chirho/robust/internals/linearize.py
@@ -27,20 +27,26 @@ def _flat_conjugate_gradient_solve(
cg_iters: Optional[int] = None,
residual_tol: float = 1e-3,
) -> torch.Tensor:
r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312.
"""
Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312.
:param f_Ax: a function to compute matrix vector products over a batch
of vectors ``x``.
:type f_Ax: Callable[[torch.Tensor], torch.Tensor]
:param b: batch of right hand sides of the equation to solve.
:type b: torch.Tensor
:param cg_iters: number of conjugate iterations to run, defaults to None
:type cg_iters: Optional[int], optional
:param residual_tol: tolerance for convergence, defaults to 1e-3
:type residual_tol: float, optional
:return: batch of solutions ``x*`` for equation Ax = b.
:rtype: torch.Tensor
Args:
f_Ax (callable): A function to compute matrix vector product.
b (torch.Tensor): Right hand side of the equation to solve.
cg_iters (int): Number of iterations to run conjugate gradient
algorithm.
residual_tol (float): Tolerence for convergence.
.. note::
Returns:
torch.Tensor: Solution x* for equation Ax = b.
Code is adapted from
https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py # noqa: E501
Notes: This code is adapted from
https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py
"""
assert len(b.shape), "b must be a 2D matrix"
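
For reference, a single right-hand-side conjugate gradient loop looks roughly like the sketch below (illustrative only, not the code in this diff); the implementation above additionally batches over the rows of ``b``, hence the 2D requirement.

.. code-block:: python

import torch

def cg_solve_single(f_Ax, b, cg_iters=10, residual_tol=1e-3):
    # Standard conjugate gradient for a single vector b, assuming f_Ax(x) = A @ x
    # with A symmetric positive definite.
    x = torch.zeros_like(b)
    r = b.clone()                  # residual b - A x, with x = 0 initially
    p = r.clone()                  # search direction
    rdotr = torch.dot(r, r)
    for _ in range(cg_iters):
        Ap = f_Ax(p)
        alpha = rdotr / torch.dot(p, Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        new_rdotr = torch.dot(r, r)
        if new_rdotr < residual_tol:
            break
        p = r + (new_rdotr / rdotr) * p
        rdotr = new_rdotr
    return x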

@@ -81,6 +87,17 @@ def _batched_product(a: torch.Tensor, B: torch.Tensor) -> torch.Tensor:


def conjugate_gradient_solve(f_Ax: Callable[[T], T], b: T, **kwargs) -> T:
"""
Use Conjugate Gradient iteration to solve Ax = b.
:param f_Ax: a function to compute matrix vector products over a batch
of vectors ``x``.
:type f_Ax: Callable[[T], T]
:param b: batch of right hand sides of the equation to solve.
:type b: T
:return: batch of solutions ``x*`` for equation Ax = b.
:rtype: T
"""
flatten, unflatten = make_flatten_unflatten(b)

def f_Ax_flat(v: torch.Tensor) -> torch.Tensor:
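
The remainder of this function is collapsed in the diff; based on the visible lines, it presumably reduces the structured problem to the flat 2D case along these lines (a hedged sketch, not the actual collapsed code):

.. code-block:: python

def conjugate_gradient_solve_sketch(f_Ax, b, **kwargs):
    # Assumed behaviour: map the structured value b (e.g. a dict of tensors)
    # to a flat matrix, solve there, and map the solution back.
    flatten, unflatten = make_flatten_unflatten(b)

    def f_Ax_flat(v):
        return flatten(f_Ax(unflatten(v)))

    return unflatten(_flat_conjugate_gradient_solve(f_Ax_flat, flatten(b), **kwargs))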
@@ -98,6 +115,90 @@ def make_empirical_fisher_vp(
*args: P.args,
**kwargs: P.kwargs,
) -> Callable[[ParamDict], ParamDict]:
r"""
Returns a function that computes the empirical Fisher vector product for an arbitrary
vector :math:`v` using only Hessian vector products via a batched version of
Pearlmutter's trick [1].
.. math::
-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) v,
where :math:`\phi` corresponds to ``log_prob_params``, :math:`\tilde{p}_{\phi}` denotes the
predictive distribution ``log_prob``, and :math:`x_n` are the data points in ``data``.
:param func_log_prob: computes the log probability of ``data`` given ``log_prob_params``
:type func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor]
:param log_prob_params: parameters of the predictive distribution
:type log_prob_params: ParamDict
:param data: data points
:type data: Point[T]
:param is_batched: if ``False``, ``func_log_prob`` is batched over ``data``
using ``torch.func.vmap``. Otherwise, assumes ``func_log_prob`` is already batched
over multiple data points. Defaults to ``False``.
:type is_batched: bool, optional
:return: a function that computes the empirical Fisher vector product for an arbitrary
vector :math:`v`
:rtype: Callable[[ParamDict], ParamDict]
**Example usage**:
.. code-block:: python
import pyro
import pyro.distributions as dist
import torch

from chirho.robust.internals.linearize import make_empirical_fisher_vp

pyro.settings.set(module_local_params=True)


class GaussianModel(pyro.nn.PyroModule):
    def __init__(self, cov_mat: torch.Tensor):
        super().__init__()
        self.register_buffer("cov_mat", cov_mat)

    def forward(self, loc):
        pyro.sample(
            "x", dist.MultivariateNormal(loc=loc, covariance_matrix=self.cov_mat)
        )


def gaussian_log_prob(params, data_point, cov_mat):
    with pyro.validation_enabled(False):
        return dist.MultivariateNormal(
            loc=params["loc"], covariance_matrix=cov_mat
        ).log_prob(data_point["x"])


v = torch.tensor([1.0, 0.0], requires_grad=False)
loc = torch.ones(2, requires_grad=True)
cov_mat = torch.ones(2, 2) + torch.eye(2)

func_log_prob = gaussian_log_prob
log_prob_params = {"loc": loc}
N_monte_carlo = 10000
data = pyro.infer.Predictive(GaussianModel(cov_mat), num_samples=N_monte_carlo)(loc)

empirical_fisher_vp_func = make_empirical_fisher_vp(
    func_log_prob, log_prob_params, data, cov_mat=cov_mat
)
empirical_fisher_vp = empirical_fisher_vp_func({"loc": v})["loc"]

# Closed form solution for the Fisher vector product
# See "Multivariate normal distribution" in https://en.wikipedia.org/wiki/Fisher_information
prec_matrix = torch.linalg.inv(cov_mat)
true_vp = prec_matrix.mv(v)

assert torch.all(torch.isclose(empirical_fisher_vp, true_vp, atol=0.1))
**References**
[1] `Fast Exact Multiplication by the Hessian`,
Barak A. Pearlmutter, 1994.
"""
N = data[next(iter(data))].shape[0] # type: ignore
mean_vector = 1 / N * torch.ones(N)
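
As a standalone illustration of the Hessian-vector-product identity referenced above (not part of this diff), the product can be computed without materializing the Hessian by differentiating the directional derivative of the gradient; ``make_empirical_fisher_vp`` applies a batched version of the same idea to the log predictive density.

.. code-block:: python

import torch

def hvp(f, theta, v):
    # jvp of grad(f) in direction v gives H(theta) @ v without forming H.
    return torch.func.jvp(torch.func.grad(f), (theta,), (v,))[1]

# sanity check on f(theta) = 0.5 * theta^T A theta, whose Hessian is A
A = torch.tensor([[2.0, 1.0], [1.0, 3.0]])
f = lambda theta: 0.5 * theta @ A @ theta
theta, v = torch.randn(2), torch.tensor([1.0, 0.0])
assert torch.allclose(hvp(f, theta, v), A @ v)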

@@ -125,9 +226,111 @@ def linearize(
num_samples_inner: Optional[int] = None,
max_plate_nesting: Optional[int] = None,
cg_iters: Optional[int] = None,
residual_tol: float = 1e-4,
pointwise_influence: bool = True,
) -> Callable[Concatenate[Point[T], P], ParamDict]:
r"""
Returns the influence function associated with the parameters
of ``guide`` and probabilistic program ``model``. This function
computes the following quantity at an arbitrary point :math:`x^{\prime}`:
.. math::
\left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) \right]^{-1}
\nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}), \quad
\tilde{p}_{\phi}(x) = \int p(x \mid \theta) q_{\phi}(\theta) d\theta,
where :math:`\phi` corresponds to ``log_prob_params``,
:math:`p(x \mid \theta)` denotes the ``model``, :math:`q_{\phi}` denotes the ``guide``,
:math:`\tilde{p}_{\phi}` denotes the predictive distribution ``log_prob`` induced
from the ``model`` and ``guide``, and :math:`\{x_n\}_{n=1}^N` are the
data points drawn iid from the predictive distribution.
:param model: Python callable containing Pyro primitives.
:type model: Callable[P, Any]
:param guide: Python callable containing Pyro primitives.
Must only contain continuous latent variables.
:type guide: Callable[P, Any]
:param num_samples_outer: number of Monte Carlo samples to
approximate Fisher information in :func:`make_empirical_fisher_vp`
:type num_samples_outer: int
:param num_samples_inner: number of Monte Carlo samples used in
:class:`BatchedNMCLogPredictiveLikelihood`. Defaults to ``num_samples_outer**2``.
:type num_samples_inner: Optional[int], optional
:param max_plate_nesting: bound on max number of nested :func:`pyro.plate`
contexts. Defaults to ``None``.
:type max_plate_nesting: Optional[int], optional
:param cg_iters: number of conjugate gradient steps used to
invert Fisher information matrix, defaults to None
:type cg_iters: Optional[int], optional
:param residual_tol: tolerance used to terminate conjugate gradients
early, defaults to 1e-4
:type residual_tol: float, optional
:param pointwise_influence: if ``True``, computes the influence function at each
point in ``points``. If ``False``, computes the efficient influence averaged
over ``points``. Defaults to True.
:type pointwise_influence: bool, optional
:return: the influence function associated with the parameters
:rtype: Callable[Concatenate[Point[T], P], ParamDict]
**Example usage**:
.. code-block:: python
import pyro
import pyro.distributions as dist
import torch

from chirho.robust.internals.linearize import linearize

pyro.settings.set(module_local_params=True)


class SimpleModel(pyro.nn.PyroModule):
    def forward(self):
        a = pyro.sample("a", dist.Normal(0, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(a, 1))
            return pyro.sample("y", dist.Normal(b, 1))


class SimpleGuide(torch.nn.Module):
    def __init__(self):
        super().__init__()
        self.loc_a = torch.nn.Parameter(torch.rand(()))
        self.loc_b = torch.nn.Parameter(torch.rand((3,)))

    def forward(self):
        a = pyro.sample("a", dist.Normal(self.loc_a, 1))
        with pyro.plate("data", 3, dim=-1):
            b = pyro.sample("b", dist.Normal(self.loc_b, 1))
            return {"a": a, "b": b}


model = SimpleModel()
guide = SimpleGuide()
predictive = pyro.infer.Predictive(
    model, guide=guide, num_samples=10, return_sites=["y"]
)
points = predictive()

influence = linearize(
    model,
    guide,
    num_samples_outer=1000,
    num_samples_inner=1000,
)

influence(points)
.. note::
* Since the efficient influence function is approximated using Monte Carlo, the result
of this function is stochastic, i.e., evaluating this function on the same ``points``
can result in different values. To reduce variance, increase ``num_samples_outer`` and
``num_samples_inner``.
* Currently, ``model`` and ``guide`` cannot contain any ``pyro.param`` statements.
This issue will be addressed in a future release:
https://github.com/BasisResearch/chirho/issues/393.
"""
assert isinstance(model, torch.nn.Module)
assert isinstance(guide, torch.nn.Module)
if num_samples_inner is None:
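
The composition of this parameter-space influence with a functional of interest is not shown in this hunk; a hedged sketch of that standard delta-method step, for a functional :math:`\psi` of the parameters :math:`\phi`, is

.. math::

    \hat{\varphi}_{\psi}(x^{\prime}) \approx \nabla_{\phi} \psi(\phi)^{\top}
    \left[-\frac{1}{N} \sum_{n=1}^N \nabla_{\phi}^2 \log \tilde{p}_{\phi}(x_n) \right]^{-1}
    \nabla_{\phi} \log \tilde{p}_{\phi}(x^{\prime}),

i.e. the Jacobian of the functional with respect to the parameters applied to the output of ``linearize``.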
41 changes: 41 additions & 0 deletions chirho/robust/internals/predictive.py
@@ -128,6 +128,12 @@ def __init__(
self.guide = guide

def forward(self, *args: P.args, **kwargs: P.kwargs) -> T:
"""
Returns a sample from the posterior predictive distribution.
:return: Sample from the posterior predictive distribution.
:rtype: T
"""
with pyro.poutine.trace() as guide_tr:
self.guide(*args, **kwargs)

@@ -192,6 +198,12 @@ def __init__(
self._mc_plate_name = name

def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]:
"""
Returns a batch of samples from the posterior predictive distribution.
:return: Dictionary of samples from the posterior predictive distribution.
:rtype: Point[T]
"""
with IndexPlatesMessenger(first_available_dim=self._first_available_dim):
with pyro.poutine.trace() as model_tr:
with BatchedLatents(self.num_samples, name=self._mc_plate_name):
@@ -211,6 +223,27 @@ def forward(self, *args: P.args, **kwargs: P.kwargs) -> Point[T]:


class BatchedNMCLogPredictiveLikelihood(Generic[P, T], torch.nn.Module):
r"""
Approximates the log predictive likelihood induced by ``model`` and ``guide``
using Monte Carlo sampling at an arbitrary batch of :math:`N`
points :math:`\{x_n\}_{n=1}^N`.
.. math::
\log \left(\frac{1}{M} \sum_{m=1}^M p(x_n \mid \theta_m)\right),
\quad \theta_m \sim q_{\phi}(\theta),
where :math:`q_{\phi}(\theta)` is the guide and :math:`p(x_n \mid \theta_m)`
is the model conditioned on the latents from the guide.
:param model: Python callable containing Pyro primitives.
:type model: torch.nn.Module
:param guide: Python callable containing Pyro primitives.
Must only contain continuous latent variables.
:type guide: torch.nn.Module
:param num_samples: Number of Monte Carlo draws :math:`M`
used to approximate predictive distribution, defaults to 1
:type num_samples: int, optional
"""
model: Callable[P, Any]
guide: Callable[P, Any]
num_samples: int
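
A self-contained numerical sketch of the estimator in the formula above (illustrative only; the class itself handles arbitrary Pyro models and batches the computation over data points and Monte Carlo draws):

.. code-block:: python

import math
import torch
import torch.distributions as dist

# Estimate log( (1/M) * sum_m p(x | theta_m) ), theta_m ~ q(theta), stably
# with logsumexp; Normals stand in here for the guide and the model.
M = 1000
x = torch.tensor(0.5)
q = dist.Normal(0.0, 1.0)                    # stand-in for the guide q_phi(theta)
theta = q.sample((M,))                       # theta_m ~ q_phi
log_p = dist.Normal(theta, 1.0).log_prob(x)  # log p(x | theta_m) under the model
log_pred = torch.logsumexp(log_p, dim=0) - math.log(M)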
@@ -238,6 +271,14 @@ def __init__(
def forward(
self, data: Point[T], *args: P.args, **kwargs: P.kwargs
) -> torch.Tensor:
"""
Computes the log predictive likelihood of ``data`` given ``model`` and ``guide``.
:param data: Dictionary of observations.
:type data: Point[T]
:return: Log predictive likelihood at each datapoint.
:rtype: torch.Tensor
"""
get_nmc_traces = get_importance_traces(PredictiveModel(self.model, self.guide))

with IndexPlatesMessenger(first_available_dim=self._first_available_dim):