Performance improvements for chirho.robust #459

Closed
wants to merge 11 commits
119 changes: 82 additions & 37 deletions chirho/robust/internals/linearize.py
@@ -5,7 +5,10 @@
import torch
from typing_extensions import Concatenate, ParamSpec

from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood
from chirho.robust.internals.predictive import (
NMCLogPredictiveLikelihood,
PointLogPredictiveLikelihood,
)
from chirho.robust.internals.utils import (
ParamDict,
make_flatten_unflatten,
@@ -25,7 +28,7 @@ def _flat_conjugate_gradient_solve(
b: torch.Tensor,
*,
cg_iters: Optional[int] = None,
residual_tol: float = 1e-10,
residual_tol: float = 1e-3,
) -> torch.Tensor:
r"""Use Conjugate Gradient iteration to solve Ax = b. Demmel p 312.

@@ -42,31 +45,41 @@ def _flat_conjugate_gradient_solve(
Notes: This code is adapted from
https://github.com/rlworkgroup/garage/blob/master/src/garage/torch/optimizers/conjugate_gradient_optimizer.py
"""
assert len(b.shape) == 2, "b must be a 2D matrix"

if cg_iters is None:
cg_iters = b.numel()
cg_iters = b.shape[1]
else:
cg_iters = min(cg_iters, b.shape[1])

def _batched_dot(x1: torch.Tensor, x2: torch.Tensor) -> torch.Tensor:
return (x1 * x2).sum(axis=-1) # type: ignore

def _batched_product(a: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
return a.unsqueeze(0).t() * B

p = b.clone()
r = b.clone()
x = torch.zeros_like(b)
z = f_Ax(p)
rdotr = torch.dot(r, r)
v = rdotr / torch.dot(p, z)
rdotr = _batched_dot(r, r)
v = rdotr / _batched_dot(p, z)
newrdotr = rdotr
mu = newrdotr / rdotr

zeros_xr = torch.zeros_like(x)

for _ in range(cg_iters):
not_converged = rdotr > residual_tol
z = torch.where(not_converged, f_Ax(p), z)
v = torch.where(not_converged, rdotr / torch.dot(p, z), v)
x += torch.where(not_converged, v * p, zeros_xr)
r -= torch.where(not_converged, v * z, zeros_xr)
newrdotr = torch.where(not_converged, torch.dot(r, r), newrdotr)
not_converged_broadcasted = not_converged.unsqueeze(0).t()
z = torch.where(not_converged_broadcasted, f_Ax(p), z)
v = torch.where(not_converged, rdotr / _batched_dot(p, z), v)
x += torch.where(not_converged_broadcasted, _batched_product(v, p), zeros_xr)
r -= torch.where(not_converged_broadcasted, _batched_product(v, z), zeros_xr)
newrdotr = torch.where(not_converged, _batched_dot(r, r), newrdotr)
mu = torch.where(not_converged, newrdotr / rdotr, mu)
p = torch.where(not_converged, r + mu * p, p)
p = torch.where(not_converged_broadcasted, r + _batched_product(mu, p), p)
rdotr = torch.where(not_converged, newrdotr, rdotr)

if torch.all(~not_converged):
return x
return x
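
The batched solver above treats each row of b as an independent right-hand side and masks rows that have already converged. A minimal self-contained sketch of the same idea (ours, not from the PR; a small eps stabilizer stands in for the torch.where masking used above):

import torch

def batched_cg(f_Ax, B, iters=10, tol=1e-10, eps=1e-12):
    # Solve A x = b for every row b of B simultaneously.
    X = torch.zeros_like(B)
    R, P = B.clone(), B.clone()
    rdotr = (R * R).sum(dim=-1)
    for _ in range(iters):
        if (rdotr <= tol).all():
            break
        Z = f_Ax(P)
        alpha = rdotr / ((P * Z).sum(dim=-1) + eps)
        X = X + alpha.unsqueeze(-1) * P
        R = R - alpha.unsqueeze(-1) * Z
        new_rdotr = (R * R).sum(dim=-1)
        P = R + (new_rdotr / (rdotr + eps)).unsqueeze(-1) * P
        rdotr = new_rdotr
    return X

A = torch.tensor([[4.0, 1.0], [1.0, 3.0]])      # symmetric positive definite
B = torch.tensor([[1.0, 2.0], [3.0, 4.0]])      # two right-hand sides, one per row
X = batched_cg(lambda P: P @ A.T, B)            # f_Ax applies A to each row
print(torch.allclose(X @ A.T, B, atol=1e-4))    # True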


@@ -85,14 +98,20 @@ def make_empirical_fisher_vp(
func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor],
log_prob_params: ParamDict,
data: Point[T],
is_batched: bool = False,
*args: P.args,
**kwargs: P.kwargs,
) -> Callable[[ParamDict], ParamDict]:
batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap(
lambda p, data: func_log_prob(p, data, *args, **kwargs),
in_dims=(None, 0),
randomness="different",
)
if not is_batched:
batched_func_log_prob: Callable[
[ParamDict, Point[T]], torch.Tensor
] = torch.vmap(
lambda p, data: func_log_prob(p, data, *args, **kwargs),
in_dims=(None, 0),
randomness="different",
)
else:
batched_func_log_prob = functools.partial(func_log_prob, *args, **kwargs)

N = data[next(iter(data))].shape[0] # type: ignore
mean_vector = 1 / N * torch.ones(N)
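
With is_batched=True, make_empirical_fisher_vp skips the internal torch.vmap and expects a log-prob function that already maps a batch of datapoints to a vector of per-point log-probabilities. A usage sketch under that assumption (the batched Gaussian log-prob below is our own stand-in, not part of the PR):

import torch
from chirho.robust.internals.linearize import make_empirical_fisher_vp

def batched_gaussian_log_prob(params, data):
    # one log N(x_i; loc, I) term per row of data["x"], up to an additive constant
    return -0.5 * ((data["x"] - params["loc"]) ** 2).sum(dim=-1)

params = {"loc": torch.randn(3)}
data = {"x": torch.randn(50, 3)}
fvp = make_empirical_fisher_vp(batched_gaussian_log_prob, params, data, is_batched=True)
print(fvp({"loc": torch.ones(3)}))  # a ParamDict with the same structure as params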
@@ -118,10 +137,11 @@ def linearize(
guide: Callable[P, Any],
*,
num_samples_outer: int,
is_point_estimate: bool = False,
num_samples_inner: Optional[int] = None,
max_plate_nesting: Optional[int] = None,
cg_iters: Optional[int] = None,
residual_tol: float = 1e-10,
residual_tol: float = 1e-4,
) -> Callable[Concatenate[Point[T], P], ParamDict]:
assert isinstance(model, torch.nn.Module)
assert isinstance(guide, torch.nn.Module)
@@ -136,32 +156,57 @@
)
predictive_params, func_predictive = make_functional_call(predictive)

log_prob = NMCLogPredictiveLikelihood(
model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
if is_point_estimate:
log_prob_type = PointLogPredictiveLikelihood
is_batched = True
else:
log_prob_type = NMCLogPredictiveLikelihood
is_batched = False

log_prob = log_prob_type(
model,
guide,
max_plate_nesting=max_plate_nesting,
)
make_efvp = functools.partial(make_empirical_fisher_vp, is_batched=is_batched)
log_prob_params, func_log_prob = make_functional_call(log_prob)
score_fn = torch.func.grad(func_log_prob)

log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
if cg_iters is None:
cg_iters = log_prob_params_numel
else:
cg_iters = min(cg_iters, log_prob_params_numel)
cg_solver = functools.partial(
conjugate_gradient_solve, cg_iters=cg_iters, residual_tol=residual_tol
)

@functools.wraps(score_fn)
def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> ParamDict:
def _fn(
points: Point[T],
pointwise_influence: bool = True,
*args: P.args,
**kwargs: P.kwargs,
) -> ParamDict:
with torch.no_grad():
data: Point[T] = func_predictive(predictive_params, *args, **kwargs)
data = {k: data[k] for k in point.keys()}
fvp = make_empirical_fisher_vp(
func_log_prob, log_prob_params, data, *args, **kwargs
)

data = {k: data[k] for k in points.keys()}
fvp = make_efvp(func_log_prob, log_prob_params, data, *args, **kwargs)
pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp)
point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs)
return cg_solver(pinned_fvp, point_score)
pinned_fvp_batched = torch.func.vmap(
lambda v: pinned_fvp(v), randomness="different"
)
if not is_point_estimate:
batched_func_log_prob = torch.vmap(
lambda p, data: func_log_prob(p, data, *args, **kwargs),
in_dims=(None, 0),
randomness="different",
)
else:
batched_func_log_prob = functools.partial(func_log_prob, *args, **kwargs)
if log_prob_params_numel > points[next(iter(points))].shape[0]:
score_fn = torch.func.jacrev(batched_func_log_prob)
else:
score_fn = torch.func.jacfwd(batched_func_log_prob, randomness="different")
point_scores: ParamDict = score_fn(log_prob_params, points)
if not pointwise_influence:
point_scores = {
k: v.mean(dim=0).unsqueeze(0) for k, v in point_scores.items()
}

return cg_solver(pinned_fvp_batched, point_scores)

return _fn
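
A usage sketch of the new entry point (assumed, not taken from the PR's tests; SimpleModel and SimpleGuide are the fixtures added later in this diff). The returned function now takes a whole batch of points and performs a single batched CG solve instead of looping over points in Python:

import torch
from chirho.robust.internals.linearize import linearize
from tests.robust.robust_fixtures import SimpleModel, SimpleGuide

model, guide = SimpleModel(), SimpleGuide()
linearized = linearize(
    model,
    guide,
    num_samples_outer=100,
    num_samples_inner=1,
    is_point_estimate=True,   # use PointLogPredictiveLikelihood instead of NMC
    cg_iters=4,
)

points = {"y": torch.randn(10, 3)}        # 10 datapoints for SimpleModel's "y" site
eif_params = linearized(points, pointwise_influence=True)
print({k: v.shape for k, v in eif_params.items()})   # one row per datapoint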
51 changes: 51 additions & 0 deletions chirho/robust/internals/predictive.py
@@ -143,3 +143,54 @@ def forward(
**kwargs,
)[0]
return torch.logsumexp(log_weights, dim=0) - math.log(self.num_samples)


class PointLogPredictiveLikelihood(NMCLogPredictiveLikelihood):
def forward(
self, data: Point[T], *args: P.args, **kwargs: P.kwargs
) -> torch.Tensor:
if self.max_plate_nesting is None:
self.max_plate_nesting = guess_max_plate_nesting(
self.model, self.guide, *args, **kwargs
)

# Retrieve point estimate by sampling from the guide once
with pyro.poutine.trace() as guide_tr:
self.guide(*args, **kwargs)

point_estimate = {k: v["value"] for k, v in guide_tr.trace.nodes.items()}
model_at_point = condition(data=point_estimate)(self.model)

# Add plate to batch over many Monte Carlo draws from model
num_monte_carlo = data[next(iter(data))].shape[0] # type: ignore

def vectorize(fn):
def _fn(*args, **kwargs):
with pyro.plate(
"__monte_carlo_samples",
size=num_monte_carlo,
dim=-self.max_plate_nesting - 1,
):
return fn(*args, **kwargs)

return _fn

batched_model = condition(data=data)(vectorize(model_at_point))

# Compute log likelihood at each Monte Carlo sample
log_like_trace = pyro.poutine.trace(batched_model).get_trace(*args, **kwargs)
log_like_trace.compute_log_prob(lambda name, site: name in data.keys())
log_prob_at_datapoints = torch.zeros(num_monte_carlo)
for site_name in data.keys():
if log_like_trace.nodes[site_name]["log_prob"].dim() > 1:
# Sum probabilities over all dimensions except first batch dimension
dims_to_sum = list(
range(1, log_like_trace.nodes[site_name]["log_prob"].dim())
)
log_prob_at_datapoints += log_like_trace.nodes[site_name][
"log_prob"
].sum(dim=dims_to_sum)
else:
log_prob_at_datapoints += log_like_trace.nodes[site_name]["log_prob"]

return log_prob_at_datapoints
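
Unlike NMCLogPredictiveLikelihood, which averages importance weights over num_samples posterior draws, this class evaluates the likelihood at a single draw from the guide (a point estimate when the guide is a point mass), returning one log-likelihood per datapoint. A small usage sketch under that reading (ours, not from the PR):

import torch
from chirho.robust.internals.predictive import PointLogPredictiveLikelihood
from tests.robust.robust_fixtures import SimpleModel, SimpleGuide

log_prob = PointLogPredictiveLikelihood(
    SimpleModel(), SimpleGuide(), max_plate_nesting=1
)
data = {"y": torch.randn(10, 3)}   # a batch of 10 datapoints
print(log_prob(data).shape)        # torch.Size([10]): one log-likelihood per point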
15 changes: 12 additions & 3 deletions chirho/robust/internals/utils.py
@@ -23,11 +23,13 @@ def make_flatten_unflatten(

@make_flatten_unflatten.register(torch.Tensor)
def _make_flatten_unflatten_tensor(v: torch.Tensor):
batch_size = v.shape[0]

def flatten(v: torch.Tensor) -> torch.Tensor:
r"""
Flatten a tensor into a single vector.
"""
return v.flatten()
return v.reshape((batch_size, -1))

def unflatten(x: torch.Tensor) -> torch.Tensor:
r"""
@@ -40,11 +42,13 @@ def unflatten(x: torch.Tensor) -> torch.Tensor:

@make_flatten_unflatten.register(dict)
def _make_flatten_unflatten_dict(d: Dict[str, torch.Tensor]):
batch_size = next(iter(d.values())).shape[0]

def flatten(d: Dict[str, torch.Tensor]) -> torch.Tensor:
r"""
Flatten a dictionary of tensors into a single vector.
"""
return torch.cat([v.flatten() for k, v in d.items()])
return torch.hstack([v.reshape((batch_size, -1)) for k, v in d.items()])

def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]:
r"""
@@ -56,7 +60,12 @@ def unflatten(x: torch.Tensor) -> Dict[str, torch.Tensor]:
[
v_flat.reshape(v.shape)
for v, v_flat in zip(
d.values(), torch.split(x, [v.numel() for k, v in d.items()])
d.values(),
torch.split(
x,
[int(v.numel() / batch_size) for k, v in d.items()],
dim=1,
),
)
],
)
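
The reworked helpers keep the leading batch dimension and flatten only the trailing dimensions, stacking the entries of a parameter dict side by side. A self-contained sketch of the round trip (ours, not from the PR):

import torch

d = {"a": torch.randn(5, 2, 3), "b": torch.randn(5, 4)}
batch_size = 5
flat = torch.hstack([v.reshape(batch_size, -1) for v in d.values()])   # shape (5, 10)
sizes = [v.numel() // batch_size for v in d.values()]                  # [6, 4]
back = {
    k: x.reshape(v.shape)
    for (k, v), x in zip(d.items(), torch.split(flat, sizes, dim=1))
}
print(all(torch.equal(d[k], back[k]) for k in d))                      # True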
23 changes: 17 additions & 6 deletions chirho/robust/ops.py
@@ -20,7 +20,7 @@ def influence_fn(
guide: Callable[P, Any],
functional: Optional[Functional[P, S]] = None,
**linearize_kwargs
) -> Callable[Concatenate[Point[T], P], S]:
) -> Callable[Concatenate[Point[T], bool, P], S]:
from chirho.robust.internals.linearize import linearize
from chirho.robust.internals.predictive import PredictiveFunctional
from chirho.robust.internals.utils import make_functional_call
@@ -39,10 +39,21 @@
target_params, func_target = make_functional_call(target)

@functools.wraps(target)
def _fn(point: Point[T], *args: P.args, **kwargs: P.kwargs) -> S:
param_eif = linearized(point, *args, **kwargs)
return torch.func.jvp(
lambda p: func_target(p, *args, **kwargs), (target_params,), (param_eif,)
)[1]
def _fn(
points: Point[T],
pointwise_influence: bool = False,
*args: P.args,
**kwargs: P.kwargs
) -> S:
param_eif = linearized(
points, pointwise_influence=pointwise_influence, *args, **kwargs
)
return torch.vmap(
lambda d: torch.func.jvp(
lambda p: func_target(p, *args, **kwargs), (target_params,), (d,)
)[1],
in_dims=0,
randomness="different",
)(param_eif)

return _fn
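
A usage sketch of the updated signature (assumed, not from the PR): the returned function takes a batch of points plus a pointwise_influence flag, and vmaps the JVP over the per-point parameter corrections produced by linearize:

import torch
from chirho.robust.ops import influence_fn
from tests.robust.robust_fixtures import SimpleModel, SimpleGuide

influence = influence_fn(
    SimpleModel(),
    SimpleGuide(),
    num_samples_outer=100,
    num_samples_inner=1,
)
points = {"y": torch.randn(10, 3)}
eif_at_points = influence(points, pointwise_influence=True)   # one value per datapoint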
31 changes: 27 additions & 4 deletions tests/robust/robust_fixtures.py
@@ -18,26 +18,29 @@ class SimpleModel(PyroModule):
def __init__(
self,
link_fn: Callable[..., dist.Distribution] = lambda mu: dist.Normal(mu, 1.0),
p: int = 3,
):
super().__init__()
self.link_fn = link_fn
self.p = p

def forward(self):
a = pyro.sample("a", dist.Normal(0, 1))
with pyro.plate("data", 3, dim=-1):
with pyro.plate("data", self.p, dim=-1):
b = pyro.sample("b", dist.Normal(a, 1))
return pyro.sample("y", dist.Normal(b, 1))


class SimpleGuide(torch.nn.Module):
def __init__(self):
def __init__(self, p: int = 3):
super().__init__()
self.loc_a = torch.nn.Parameter(torch.rand(()))
self.loc_b = torch.nn.Parameter(torch.rand((3,)))
self.loc_b = torch.nn.Parameter(torch.rand((p,)))
self.p = p

def forward(self):
a = pyro.sample("a", dist.Normal(self.loc_a, 1))
with pyro.plate("data", 3, dim=-1):
with pyro.plate("data", self.p, dim=-1):
b = pyro.sample("b", dist.Normal(self.loc_b, 1))
return {"a": a, "b": b}

@@ -63,6 +66,15 @@ def gaussian_log_prob(params: ParamDict, data_point: Point[T], cov_mat) -> T:
).log_prob(data_point["x"])


def gaussian_log_prob_flattened(
params: torch.Tensor, data_point: Point[T], cov_mat: torch.Tensor
) -> torch.Tensor:
with pyro.validation_enabled(False):
return pyro.distributions.MultivariateNormal(
loc=params, covariance_matrix=cov_mat
).log_prob(data_point["x"])


class DataConditionedModel(PyroModule):
r"""
Helper class for conditioning on data.
@@ -228,3 +240,14 @@ def closed_form_ate_correction(
analytic_eif_at_test_pts = (A / pi_X - (1 - A) / (1 - pi_X)) * (Y - mu_X)
analytic_correction = analytic_eif_at_test_pts.mean()
return analytic_correction, analytic_eif_at_test_pts


def _make_fisher_jvp_score_formulation(
f: Callable, params: torch.Tensor, num_monte_carlo: int
) -> Callable:
def empirical_fisher_vp(v):
vnew = torch.func.jvp(f, (params,), (v / num_monte_carlo,))[1]
(_, vjpfunc) = torch.func.vjp(f, params)
return vjpfunc(vnew)[0]

return empirical_fisher_vp
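
This test helper realizes the empirical Fisher vector product through its score formulation, (1/N) J^T J v, where J stacks the per-datapoint score rows. A self-contained numerical check of the same identity (our sketch, not part of the PR):

import torch

def per_point_log_prob(params, x):
    # log N(x_i; params, I) for each row x_i, up to an additive constant
    return -0.5 * ((x - params) ** 2).sum(dim=-1)

params = torch.randn(3)
x = torch.randn(100, 3)
f = lambda p: per_point_log_prob(p, x)
N = x.shape[0]

def fisher_vp(v):
    # (1/N) J^T (J v) without materializing J
    jv = torch.func.jvp(f, (params,), (v / N,))[1]
    return torch.func.vjp(f, params)[1](jv)[0]

J = torch.func.jacrev(f)(params)          # N x 3 matrix of per-point scores
fisher = J.T @ J / N                      # explicit empirical Fisher
v = torch.randn(3)
print(torch.allclose(fisher_vp(v), fisher @ v, atol=1e-4))   # True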