Finite Difference Baseline #508

Merged 76 commits on Jan 19, 2024
Changes from 1 commit
Commits (76)
38b6158
added robust folder
agrawalraj Nov 8, 2023
c09dcab
uncommited scratch work for log prob
agrawalraj Nov 9, 2023
21e31bf
untested variational log prob
agrawalraj Nov 9, 2023
faed235
uncomitted changes
agrawalraj Nov 13, 2023
fac98cd
uncomitted changes
agrawalraj Nov 16, 2023
4edcb5e
pair coding w/ eli
agrawalraj Nov 16, 2023
fe17403
added tests w/ Eli
agrawalraj Nov 17, 2023
b159687
eif
eb8680 Nov 17, 2023
33f4811
linting
agrawalraj Nov 18, 2023
8e171f4
moving test autograd to internals and deleted old utils file
agrawalraj Nov 20, 2023
93cc014
sketch influence implementation
eb8680 Nov 21, 2023
9bc704c
fix more args
eb8680 Nov 21, 2023
cedb818
ops file
eb8680 Nov 21, 2023
418f792
file
eb8680 Nov 21, 2023
f792ddf
format
eb8680 Nov 21, 2023
88a100b
lint
eb8680 Nov 21, 2023
94c2fc6
clean up influence and tests
eb8680 Nov 21, 2023
da0bc5c
make tests more generic
eb8680 Nov 22, 2023
4d027e4
guess max plate nesting
eb8680 Nov 22, 2023
e85e33f
linearize
eb8680 Nov 22, 2023
1734191
rename file
eb8680 Nov 22, 2023
f46556b
tensor flatten
eb8680 Nov 22, 2023
1abc5e0
predictive eif
eb8680 Nov 22, 2023
9c80b60
jvp type
eb8680 Nov 22, 2023
931da4f
reorganize files
eb8680 Nov 22, 2023
dc63f31
shrink test case
eb8680 Nov 22, 2023
be3bc8d
move guess_max_plate_nesting
eb8680 Nov 22, 2023
9ce164a
move cg solver to linearze
eb8680 Nov 22, 2023
81196d4
type alias
eb8680 Nov 22, 2023
30cb2e7
test_ops
eb8680 Nov 22, 2023
21cf2d7
basic cg tests
eb8680 Nov 22, 2023
720661f
remove failing test case
eb8680 Nov 22, 2023
91833da
format
eb8680 Nov 22, 2023
548069a
move paramdict up
eb8680 Nov 22, 2023
12b22c0
remove obsolete test files
eb8680 Nov 22, 2023
d2bbf9d
Merge branch 'master' into staging-robust
eb8680 Nov 22, 2023
3b72bb0
add empty handlers
eb8680 Nov 22, 2023
89d9f6b
add chirho.robust to docs
eb8680 Nov 22, 2023
7582c22
fix memory leak in tests
eb8680 Nov 27, 2023
82c23e8
make typing compatible with python 3.8
eb8680 Nov 27, 2023
e08d9d6
typing_extensions
eb8680 Nov 27, 2023
22eae09
add branch to ci
eb8680 Nov 27, 2023
d0014db
predictive
eb8680 Nov 27, 2023
e5342dc
remove imprecise annotation
eb8680 Nov 27, 2023
be13ac5
Merge branch 'master' into staging-robust
SamWitty Nov 28, 2023
c5fe64b
Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)
agrawalraj Dec 6, 2023
117d645
Add upper bound on number of CG steps (#404)
eb8680 Dec 7, 2023
8fe1b25
fixed test for non-symmetric matrix (#437)
agrawalraj Dec 7, 2023
3f0c83d
Make `NMCLogPredictiveLikelihood` seeded (#408)
agrawalraj Dec 8, 2023
4d41807
Use Hessian formulation of Fisher information in `make_empirical_fish…
agrawalraj Dec 8, 2023
2e01b7b
Add new `SimpleModel` and `SimpleGuide` (#440)
agrawalraj Dec 8, 2023
538cef8
Batching in `linearize` and `influence` (#465)
agrawalraj Dec 22, 2023
6bba70b
batched cg (#466)
agrawalraj Dec 22, 2023
f143d3a
One step correction implemented (#467)
agrawalraj Dec 22, 2023
878eb0d
Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLog…
eb8680 Jan 2, 2024
3cfe319
Added documentation for `chirho.robust` (#470)
agrawalraj Jan 2, 2024
5d77fe0
Make functional argument to influence_fn required (#487)
eb8680 Jan 9, 2024
013d518
Remove guide argument from `influence_fn` and `linearize` (#489)
eb8680 Jan 9, 2024
c4346c8
Make influence_fn a higher-order Functional (#492)
eb8680 Jan 11, 2024
9207e3e
Add full corrected one step estimator (#476)
SamWitty Jan 12, 2024
ca916cd
Merge branch 'master' into staging-robust
eb8680 Jan 12, 2024
a7875c6
add abstractions and simple temp scratch to test with squared unit no…
azane Jan 12, 2024
ad519be
removes old scratch notebook
azane Jan 12, 2024
127a4a4
Merge branch 'staging-robust' into az-influence-finite-difference-2
azane Jan 12, 2024
1efe6ea
gets squared density running under abstraction that couples functiona…
azane Jan 12, 2024
44785d8
gets quad and mc approximations to match, vectorization hacky.
azane Jan 12, 2024
5a11a7a
Merge branch 'staging-robust-icml' into az-influence-finite-difference-2
azane Jan 16, 2024
31cc9ac
adds plotting and comparative to analytic.
azane Jan 16, 2024
f867f2a
adds scratch experiment comparing squared density analytic vs fd appr…
azane Jan 17, 2024
7f10667
fixes dataset splitting, breaks analytic eif
azane Jan 17, 2024
094562a
unfixes an incorrect fix, working now.
azane Jan 17, 2024
0556543
Merge branch 'staging-robust-icml' into az-influence-finite-difference-2
azane Jan 17, 2024
327779a
Merge branch 'staging-robust-icml' into az-influence-finite-difference-2
azane Jan 18, 2024
3e33dc9
refactors finite difference machinery to fit experimental specs.
azane Jan 18, 2024
b21a882
switches to existing rng seed context manager.
azane Jan 19, 2024
79989f9
reverts back to what turns out to be a slightly different seeding con…
azane Jan 19, 2024
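The commits above build a finite-difference baseline against which the analytic efficient influence function can be checked. The core idea can be shown with a stdlib-only sketch (all names, the mean functional, and the step size below are illustrative, not the PR's API): nudge the empirical distribution toward a point mass at x and divide the change in the functional by the perturbation size.

```python
def functional_mean(weighted_points):
    # T(P): the mean functional, evaluated on a finitely supported
    # distribution given as (point, weight) pairs.
    return sum(w * x for x, w in weighted_points)


def perturb(points, x, eps):
    # The mixture (1 - eps) * P_n + eps * delta_x, where P_n is the
    # empirical (uniform) distribution over `points`.
    n = len(points)
    return [(p, (1.0 - eps) / n) for p in points] + [(x, eps)]


def fd_influence(functional, points, x, eps=1e-4):
    # Finite-difference influence estimate:
    #   psi(x) ~= (T((1 - eps) P_n + eps delta_x) - T(P_n)) / eps
    n = len(points)
    base = functional([(p, 1.0 / n) for p in points])
    return (functional(perturb(points, x, eps)) - base) / eps


data = [1.0, 2.0, 3.0, 4.0]
point = 10.0
psi_fd = fd_influence(functional_mean, data, point)
# For the mean functional the influence function is known in closed
# form, psi(x) = x - mean(P), which gives a check on the FD estimate.
psi_analytic = point - sum(data) / len(data)
```

For the mean the two agree up to floating-point error because the functional is linear in the mixture weight; for nonlinear functionals (as in the squared-density experiments in the commits above) the step size trades off truncation against noise.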
130 changes: 70 additions & 60 deletions chirho/robust/internals.py
@@ -1,10 +1,10 @@
-from typing import ParamSpec, Callable, TypeVar, Optional
+from typing import ParamSpec, Callable, TypeVar, Optional, Dict, List
 import torch
 from pyro.infer import Predictive
 from pyro.infer import Trace_ELBO
 from pyro.infer.elbo import ELBOModule
 from pyro.infer.importance import vectorized_importance_weights
-from pyro.poutine import mask, replay, trace
+from pyro.poutine import block, replay, trace, mask

 P = ParamSpec("P")
 Q = ParamSpec("Q")
@@ -15,19 +15,47 @@
 Guide = Callable[P, Optional[T | Point[T]]]


+# guide should hide obs_names sites
+def _shuffle_dict(d: dict[str, T]):
+    """
+    Shuffle values of a dictionary in first batch dimension
+    """
+    return {k: v[torch.randperm(v.shape[0])] for k, v in d.items()}
+
+
+# Need to add vectorize function from vectorized_importance_weights
+
+
+# Issue: gradients detached in predictives
 def vectorized_variational_log_prob(
-    model: Callable[P, T], guide: Guide[P, T], X: Point, *args, **kwargs
+    model: Callable[P, T],
+    guide: Guide[P, T],
+    trace_predictive: Dict,
+    obs_names: List[str],
+    # num_particles: int = 1,  # TODO: support this next
+    *args,
+    **kwargs
 ):
-    guide_trace = trace(guide).get_trace(*args, **kwargs)
-    model_trace = trace(replay(model, guide_trace)).get_trace(*args, **kwargs)
-    log_probs = dict()
-    for site_name, site_val in X.items():
+    """
+    See eq. 3 in http://approximateinference.org/2017/accepted/TangRanganath2017.pdf
+    """
+    latent_params_trace = _shuffle_dict(
+        {k: v.clone() for k, v in trace_predictive.items() if k not in obs_names}
+    )
+    obs_vars_trace = {
+        k: v.clone().detach() for k, v in trace_predictive.items() if k in obs_names
+    }
+    import pdb
+
+    pdb.set_trace()
+    model_trace = trace(replay(model, latent_params_trace)).get_trace(*args, **kwargs)
+
+    N_samples = next(iter(latent_params_trace.values())).shape[0]
+
+    log_probs = torch.zeros(N_samples)
+    for site_name, site_val in obs_vars_trace.items():
         site = model_trace.nodes[site_name]
         assert site["type"] == "sample"
-        log_probs[site_name] = site["fn"].log_prob(site_val)
+        log_probs += site["fn"].log_prob(site_val)
     return log_probs


@@ -61,64 +89,46 @@ def log_prob(self, X: Point, *args, **kwargs) -> torch.Tensor:
         pass


+def log_likelihood_fn(flat_theta: torch.tensor, X: Dict[str, torch.Tensor]):
+    n_monte_carlo = X[next(iter(X))].shape[0]
+    theta = _unflatten_dict(flat_theta, theta_hat)
+    model_at_theta = condition(data=theta)(DataConditionedModel(model))
+    log_like_trace = pyro.poutine.trace(model_at_theta).get_trace(X)
+    log_like_trace.compute_log_prob()
+    log_prob_at_datapoints = torch.zeros(n_monte_carlo)
+    for name in obs_names:
+        log_prob_at_datapoints += log_like_trace.nodes[name]["log_prob"]
+    return log_prob_at_datapoints
+
+
 # For continous latents, vectorized importance weights
 # https://docs.pyro.ai/en/stable/inference_algos.html#pyro.infer.importance.vectorized_importance_weights

 # Predictive(model, guide)

 if __name__ == "__main__":
     import pyro
     import pyro.distributions as dist

-    import pyro
-    import pyro.distributions as dist
-
-
-    # Create simple pyro model
-    def model(x: torch.Tensor) -> torch.Tensor:
-        a = pyro.sample("a", dist.Normal(0, 1))
-        b = pyro.sample("b", dist.Normal(0, 1))
-        with pyro.plate("data", x.shape[0]):
-            y = a * x + b
-            return pyro.sample("y", dist.Normal(y, 1))
-
-
-    # Create guide
-    guide_normal = pyro.infer.autoguide.AutoNormal(model)
-
-
-    def fixed_guide(x: torch.Tensor) -> None:
-        pyro.sample("a", dist.Delta(torch.tensor(1.0)))
-        pyro.sample("b", dist.Delta(torch.tensor(1.0)))
-
-
-    # Create predictive
-    predictive = Predictive(model, guide=fixed_guide, num_samples=1000)
-
-    samps = predictive(torch.tensor([1.0]))
-
-    # Create elbo loss
-    elbo = pyro.infer.Trace_ELBO(num_particles=100)(model, guide=guide_normal)
-
-    torch.autograd(elbo(torch.tensor([1.0])), elbo.parameters())
-
-    torch.autograd.functional.jacobian(
-        elbo,
-        torch.tensor([1.0, 2.0]),
-        dict(elbo.named_parameters())["guide.locs.a_unconstrained"],
-    )
-
-    x0 = torch.tensor([1.0, 2.0], requires_grad=False)
-
-    elbo(x0)
-
-    x1 = torch.tensor([[1.0, 2.0], [2.0, 3.0]], requires_grad=True)
-
-    vectorized_importance_weights(
-        model, guide_normal, x=x0, max_plate_nesting=4, num_samples=10000
-    )[0].mean()
+    # Create simple pyro model
+    def model():
+        a = pyro.sample("a", dist.Normal(0, 1))
+        b = pyro.sample("b", dist.Normal(0, 1))
+        return pyro.sample("y", dist.Normal(a + b, 1))
+
+    # Create guide on latents a and b
+    guide = pyro.infer.autoguide.AutoNormal(block(model, hide=["y"]))
+    # with pyro.poutine.trace() as tr:
+    #     guide()
+    #     print(tr.trace.nodes.keys())
+    # Create predictive
+    predictive = Predictive(model, guide=guide, num_samples=100)
+    # with pyro.poutine.trace() as tr:
+    X = predictive()
+
+    torch.stack([torch.zeros(3), torch.zeros(3)])
+    vectorized_variational_log_prob(model, guide, X, ["y"])
+
+    # print(X)
+    # import pdb
+
+    elbo.parameters()
+    # pdb.set_trace()
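Several commits in this PR ("move cg solver to linearze", "Add upper bound on number of CG steps (#404)", "batched cg (#466)") concern a conjugate-gradient solver used during linearization, where Fisher-vector products must be inverted without materializing the matrix. A stdlib-only, unbatched sketch of CG with a capped iteration count (the function name and the 2x2 example are illustrative, not the PR's batched implementation):

```python
def conjugate_gradient(matvec, b, max_steps=None, tol=1e-10):
    # Solve A x = b for symmetric positive-definite A, given only the
    # matrix-vector product matvec(v) = A @ v. The iteration count is
    # capped at max_steps (CG converges in at most len(b) exact steps).
    n = len(b)
    max_steps = n if max_steps is None else max_steps
    x = [0.0] * n
    r = list(b)          # residual b - A x, with x = 0 initially
    p = list(r)          # search direction
    rs = sum(ri * ri for ri in r)
    for _ in range(max_steps):
        Ap = matvec(p)
        alpha = rs / sum(pi * api for pi, api in zip(p, Ap))
        x = [xi + alpha * pi for xi, pi in zip(x, p)]
        r = [ri - alpha * api for ri, api in zip(r, Ap)]
        rs_new = sum(ri * ri for ri in r)
        if rs_new < tol:
            break
        p = [ri + (rs_new / rs) * pi for ri, pi in zip(r, p)]
        rs = rs_new
    return x


# Example: solve [[4, 1], [1, 3]] x = [1, 2]; exact answer (1/11, 7/11).
A = [[4.0, 1.0], [1.0, 3.0]]
x = conjugate_gradient(
    lambda v: [sum(a * vi for a, vi in zip(row, v)) for row in A],
    [1.0, 2.0],
)
```

The cap on steps matters in this setting because each `matvec` is an empirical Fisher-vector product over data, so a fixed budget bounds cost even when the system is ill-conditioned.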