Replace some torch.vmap usage with a hand-vectorized BatchedNMCLogPredictiveLikelihood #473

Merged · 48 commits · Jan 2, 2024

Commits
71ac10b
sketch batched nmc lpd
eb8680 Dec 29, 2023
9f233a7
nits
eb8680 Dec 29, 2023
1dbd25a
fix type
eb8680 Dec 29, 2023
748292a
format
eb8680 Dec 29, 2023
c33fecf
comment
eb8680 Dec 29, 2023
aa610f2
comment
eb8680 Dec 29, 2023
1e9b0c6
comment
eb8680 Dec 29, 2023
0d61de6
typo
eb8680 Dec 29, 2023
f2ac532
typo
eb8680 Dec 29, 2023
90e76b7
add condition to help guarantee idempotence
eb8680 Dec 29, 2023
3ce97f7
simplify edge case
eb8680 Dec 29, 2023
7098012
simplify plate_name
eb8680 Dec 29, 2023
27d1d46
simplify batchedobservation logic
eb8680 Dec 29, 2023
c9c7746
factorize
eb8680 Dec 29, 2023
79823b3
simplify batched
eb8680 Dec 29, 2023
db0e404
reorder
eb8680 Dec 29, 2023
f8e80cb
comment
eb8680 Dec 29, 2023
e7bdaa7
remove plate_names
eb8680 Dec 29, 2023
223bbe0
types
eb8680 Dec 29, 2023
2840d92
formatting and type
eb8680 Dec 29, 2023
e38d34f
move unbind to utils
eb8680 Dec 29, 2023
618e2b3
remove max_plate_nesting arg from get_traces
eb8680 Dec 29, 2023
f6484ea
comment
eb8680 Dec 29, 2023
f070f5f
nit
eb8680 Dec 29, 2023
317d8b9
move get_importance_traces to utils
eb8680 Dec 30, 2023
af05a2f
fix types
eb8680 Dec 30, 2023
f2c1006
generic obs type
eb8680 Dec 30, 2023
04a7575
lint
eb8680 Dec 30, 2023
963d601
format
eb8680 Dec 30, 2023
e14875d
handle observe in batchedobservations
eb8680 Dec 30, 2023
85bf548
event dim
eb8680 Dec 30, 2023
4d8d413
move batching handlers to utils
eb8680 Dec 30, 2023
aef852c
replace 2/3 vmaps, tests pass
eb8680 Dec 30, 2023
f8f1e51
remove dead code
eb8680 Dec 30, 2023
c6e5760
format
eb8680 Dec 30, 2023
403e630
name args
eb8680 Dec 30, 2023
5dbd5df
lint
eb8680 Dec 30, 2023
b2ce754
shuffle code
eb8680 Dec 30, 2023
1a40fb2
try an extra optimization in batchedlatents
eb8680 Dec 30, 2023
b218eda
add another optimization
eb8680 Dec 30, 2023
c08eda9
undo changes to test
eb8680 Dec 30, 2023
caedac8
remove inplace adds
eb8680 Jan 2, 2024
13344d7
add performance test showing speedup
eb8680 Jan 2, 2024
9b3e962
document internal helpers
eb8680 Jan 2, 2024
8061119
batch latents test
eb8680 Jan 2, 2024
fb2b315
move batch handlers to predictive
eb8680 Jan 2, 2024
e33756d
add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel
eb8680 Jan 2, 2024
6d2d3e6
use bind_leftmost_dim in log prob
eb8680 Jan 2, 2024
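Before the file diffs, a toy sketch of the idea named in the PR title. The function names below are invented for illustration; the real BatchedNMCLogPredictiveLikelihood involves Monte Carlo sampling internally, which is why the removed torch.vmap calls needed randomness="different". The pattern: rather than wrapping a per-datapoint log-likelihood in torch.vmap, expose a callable that accepts a whole batch of datapoints and returns one log-probability per datapoint, so downstream code such as make_empirical_fisher_vp can call it directly.

```python
import torch

params = {"loc": torch.zeros(2)}
data = torch.randn(5, 2)  # a batch of N = 5 datapoints

def per_point_log_prob(p, x):
    # log-likelihood of a single datapoint under a unit-variance Gaussian
    return torch.distributions.Normal(p["loc"], 1.0).log_prob(x).sum(-1)

# The pattern being removed: vmap over the leading dimension of `data`.
lp_vmap = torch.vmap(per_point_log_prob, in_dims=(None, 0))(params, data)

# The pattern being introduced: a hand-vectorized function that handles the
# batch dimension itself (here trivially, via broadcasting).
def batched_log_prob(p, x):
    return torch.distributions.Normal(p["loc"], 1.0).log_prob(x).sum(-1)

lp_batched = batched_log_prob(params, data)
assert lp_vmap.shape == lp_batched.shape == (5,)
assert torch.allclose(lp_vmap, lp_batched)
```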
8 changes: 5 additions & 3 deletions chirho/observational/handlers/condition.py
@@ -1,10 +1,10 @@
-from typing import Callable, Generic, Hashable, Mapping, TypeVar, Union
+from typing import Callable, Generic, Mapping, TypeVar, Union
 
 import pyro
 import torch
 
 from chirho.observational.internals import ObserveNameMessenger
-from chirho.observational.ops import AtomicObservation, observe
+from chirho.observational.ops import Observation, observe
 
 T = TypeVar("T")
 R = Union[float, torch.Tensor]
@@ -62,7 +62,9 @@ class Observations(Generic[T], ObserveNameMessenger):
     a richer set of observational data types and enables counterfactual inference.
     """
 
-    def __init__(self, data: Mapping[Hashable, AtomicObservation[T]]):
+    data: Mapping[str, Observation[T]]
+
+    def __init__(self, data: Mapping[str, Observation[T]]):
         self.data = data
         super().__init__()
 
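A minimal, hypothetical usage sketch of the widened signature. The model and data below are invented for illustration and assume Observations behaves as a standard Pyro effect handler; only the Mapping[str, Observation[T]] constructor signature comes from this diff.

```python
import pyro
import pyro.distributions as dist
import torch

from chirho.observational.handlers.condition import Observations

def model():
    x = pyro.sample("x", dist.Normal(0.0, 1.0))
    return pyro.sample("y", dist.Normal(x, 1.0))

# Keys are now plain strings (site names); values are Observation[T],
# which is broader than the previous AtomicObservation[T].
data = {"y": torch.tensor(1.5)}

with Observations(data=data):
    tr = pyro.poutine.trace(model).get_trace()
```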
28 changes: 8 additions & 20 deletions chirho/robust/internals/linearize.py
@@ -5,7 +5,7 @@
 import torch
 from typing_extensions import Concatenate, ParamSpec
 
-from chirho.robust.internals.predictive import NMCLogPredictiveLikelihood
+from chirho.robust.internals.predictive import BatchedNMCLogPredictiveLikelihood
 from chirho.robust.internals.utils import (
     ParamDict,
     make_flatten_unflatten,
@@ -92,23 +92,17 @@ def f_Ax_flat(v: torch.Tensor) -> torch.Tensor:
 
 
 def make_empirical_fisher_vp(
-    func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor],
+    batched_func_log_prob: Callable[Concatenate[ParamDict, Point[T], P], torch.Tensor],
     log_prob_params: ParamDict,
     data: Point[T],
     *args: P.args,
     **kwargs: P.kwargs,
 ) -> Callable[[ParamDict], ParamDict]:
-    batched_func_log_prob: Callable[[ParamDict, Point[T]], torch.Tensor] = torch.vmap(
-        lambda p, data: func_log_prob(p, data, *args, **kwargs),
-        in_dims=(None, 0),
-        randomness="different",
-    )
-
     N = data[next(iter(data))].shape[0]  # type: ignore
     mean_vector = 1 / N * torch.ones(N)
 
     def bound_batched_func_log_prob(params: ParamDict) -> torch.Tensor:
-        return batched_func_log_prob(params, data)
+        return batched_func_log_prob(params, data, *args, **kwargs)
 
     def _empirical_fisher_vp(v: ParamDict) -> ParamDict:
         def jvp_fn(log_prob_params: ParamDict) -> torch.Tensor:
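For readers unfamiliar with what make_empirical_fisher_vp computes: it implements the empirical Fisher vector product, roughly F v = (1/N) Σ_n g_n ⟨g_n, v⟩ with g_n = ∇_θ log p_θ(x_n), via a jvp followed by a vjp through the batched log-probability. Below is a toy sketch of that pattern using a flat tensor of parameters instead of a ParamDict and an invented log-probability; it is not the chirho implementation.

```python
import torch
from torch.func import jacrev, jvp, vjp

theta = torch.randn(3)    # toy stand-in for ParamDict
data = torch.randn(7, 3)  # N = 7 datapoints
N = data.shape[0]

def batched_log_prob(t, x):
    # one toy log-probability per datapoint, shape (N,)
    return -0.5 * ((x - t) ** 2).sum(-1)

def empirical_fisher_vp(v):
    f = lambda t: batched_log_prob(t, data)
    # jvp: per-datapoint directional derivatives <g_n, v>, shape (N,)
    _, dir_derivs = jvp(f, (theta,), (v,))
    # vjp with cotangent dir_derivs / N gives (1/N) sum_n <g_n, v> g_n
    _, vjp_fn = vjp(f, theta)
    return vjp_fn(dir_derivs / N)[0]

v = torch.randn(3)
# sanity check against the explicit empirical Fisher matrix (1/N) sum_n g_n g_n^T
grads = jacrev(batched_log_prob)(theta, data)  # shape (N, 3)
F = grads.T @ grads / N
assert torch.allclose(empirical_fisher_vp(v), F @ v, atol=1e-5)
```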
@@ -145,12 +139,11 @@ def linearize(
         num_samples=num_samples_outer,
         parallel=True,
     )
-    predictive_params, func_predictive = make_functional_call(predictive)
 
-    log_prob = NMCLogPredictiveLikelihood(
+    batched_log_prob = BatchedNMCLogPredictiveLikelihood(
         model, guide, num_samples=num_samples_inner, max_plate_nesting=max_plate_nesting
     )
-    log_prob_params, func_log_prob = make_functional_call(log_prob)
+    log_prob_params, batched_func_log_prob = make_functional_call(batched_log_prob)
     log_prob_params_numel: int = sum(p.numel() for p in log_prob_params.values())
     if cg_iters is None:
         cg_iters = log_prob_params_numel
@@ -166,23 +159,18 @@ def _fn(
         **kwargs: P.kwargs,
     ) -> ParamDict:
         with torch.no_grad():
-            data: Point[T] = func_predictive(predictive_params, *args, **kwargs)
+            data: Point[T] = predictive(*args, **kwargs)
             data = {k: data[k] for k in points.keys()}
         fvp = make_empirical_fisher_vp(
-            func_log_prob, log_prob_params, data, *args, **kwargs
+            batched_func_log_prob, log_prob_params, data, *args, **kwargs
         )
         pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp)
         pinned_fvp_batched = torch.func.vmap(
             lambda v: pinned_fvp(v), randomness="different"
         )
-        batched_func_log_prob = torch.vmap(
-            lambda p, data: func_log_prob(p, data, *args, **kwargs),
-            in_dims=(None, 0),
-            randomness="different",
-        )
 
         def bound_batched_func_log_prob(p: ParamDict) -> torch.Tensor:
-            return batched_func_log_prob(p, points)
+            return batched_func_log_prob(p, points, *args, **kwargs)
 
         if pointwise_influence:
             score_fn = torch.func.jacrev(bound_batched_func_log_prob)
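For context on the pointwise_influence branch at the end of this hunk: applying torch.func.jacrev to a function that returns one log-probability per evaluation point yields one score (gradient) vector per point. A toy sketch of that pattern with flat tensor parameters and an invented log-probability (not the chirho code):

```python
import torch
from torch.func import jacrev

theta = torch.randn(3)
points = torch.randn(5, 3)  # N = 5 evaluation points

def batched_log_prob(t, x):
    # one toy log-probability per point, shape (N,)
    return -0.5 * ((x - t) ** 2).sum(-1)

def bound_batched_log_prob(t):
    return batched_log_prob(t, points)

# jacrev of an (N,)-valued function w.r.t. theta: one score vector per point
point_scores = jacrev(bound_batched_log_prob)(theta)
assert point_scores.shape == (5, 3)
```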