Finite Difference Baseline #508

Merged · 76 commits · Jan 19, 2024

Commits (76)
38b6158
added robust folder
agrawalraj Nov 8, 2023
c09dcab
uncommited scratch work for log prob
agrawalraj Nov 9, 2023
21e31bf
untested variational log prob
agrawalraj Nov 9, 2023
faed235
uncomitted changes
agrawalraj Nov 13, 2023
fac98cd
uncomitted changes
agrawalraj Nov 16, 2023
4edcb5e
pair coding w/ eli
agrawalraj Nov 16, 2023
fe17403
added tests w/ Eli
agrawalraj Nov 17, 2023
b159687
eif
eb8680 Nov 17, 2023
33f4811
linting
agrawalraj Nov 18, 2023
8e171f4
moving test autograd to internals and deleted old utils file
agrawalraj Nov 20, 2023
93cc014
sketch influence implementation
eb8680 Nov 21, 2023
9bc704c
fix more args
eb8680 Nov 21, 2023
cedb818
ops file
eb8680 Nov 21, 2023
418f792
file
eb8680 Nov 21, 2023
f792ddf
format
eb8680 Nov 21, 2023
88a100b
lint
eb8680 Nov 21, 2023
94c2fc6
clean up influence and tests
eb8680 Nov 21, 2023
da0bc5c
make tests more generic
eb8680 Nov 22, 2023
4d027e4
guess max plate nesting
eb8680 Nov 22, 2023
e85e33f
linearize
eb8680 Nov 22, 2023
1734191
rename file
eb8680 Nov 22, 2023
f46556b
tensor flatten
eb8680 Nov 22, 2023
1abc5e0
predictive eif
eb8680 Nov 22, 2023
9c80b60
jvp type
eb8680 Nov 22, 2023
931da4f
reorganize files
eb8680 Nov 22, 2023
dc63f31
shrink test case
eb8680 Nov 22, 2023
be3bc8d
move guess_max_plate_nesting
eb8680 Nov 22, 2023
9ce164a
move cg solver to linearze
eb8680 Nov 22, 2023
81196d4
type alias
eb8680 Nov 22, 2023
30cb2e7
test_ops
eb8680 Nov 22, 2023
21cf2d7
basic cg tests
eb8680 Nov 22, 2023
720661f
remove failing test case
eb8680 Nov 22, 2023
91833da
format
eb8680 Nov 22, 2023
548069a
move paramdict up
eb8680 Nov 22, 2023
12b22c0
remove obsolete test files
eb8680 Nov 22, 2023
d2bbf9d
Merge branch 'master' into staging-robust
eb8680 Nov 22, 2023
3b72bb0
add empty handlers
eb8680 Nov 22, 2023
89d9f6b
add chirho.robust to docs
eb8680 Nov 22, 2023
7582c22
fix memory leak in tests
eb8680 Nov 27, 2023
82c23e8
make typing compatible with python 3.8
eb8680 Nov 27, 2023
e08d9d6
typing_extensions
eb8680 Nov 27, 2023
22eae09
add branch to ci
eb8680 Nov 27, 2023
d0014db
predictive
eb8680 Nov 27, 2023
e5342dc
remove imprecise annotation
eb8680 Nov 27, 2023
be13ac5
Merge branch 'master' into staging-robust
SamWitty Nov 28, 2023
c5fe64b
Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)
agrawalraj Dec 6, 2023
117d645
Add upper bound on number of CG steps (#404)
eb8680 Dec 7, 2023
8fe1b25
fixed test for non-symmetric matrix (#437)
agrawalraj Dec 7, 2023
3f0c83d
Make `NMCLogPredictiveLikelihood` seeded (#408)
agrawalraj Dec 8, 2023
4d41807
Use Hessian formulation of Fisher information in `make_empirical_fish…
agrawalraj Dec 8, 2023
2e01b7b
Add new `SimpleModel` and `SimpleGuide` (#440)
agrawalraj Dec 8, 2023
538cef8
Batching in `linearize` and `influence` (#465)
agrawalraj Dec 22, 2023
6bba70b
batched cg (#466)
agrawalraj Dec 22, 2023
f143d3a
One step correction implemented (#467)
agrawalraj Dec 22, 2023
878eb0d
Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLog…
eb8680 Jan 2, 2024
3cfe319
Added documentation for `chirho.robust` (#470)
agrawalraj Jan 2, 2024
5d77fe0
Make functional argument to influence_fn required (#487)
eb8680 Jan 9, 2024
013d518
Remove guide argument from `influence_fn` and `linearize` (#489)
eb8680 Jan 9, 2024
c4346c8
Make influence_fn a higher-order Functional (#492)
eb8680 Jan 11, 2024
9207e3e
Add full corrected one step estimator (#476)
SamWitty Jan 12, 2024
ca916cd
Merge branch 'master' into staging-robust
eb8680 Jan 12, 2024
a7875c6
add abstractions and simple temp scratch to test with squared unit no…
azane Jan 12, 2024
ad519be
removes old scratch notebook
azane Jan 12, 2024
127a4a4
Merge branch 'staging-robust' into az-influence-finite-difference-2
azane Jan 12, 2024
1efe6ea
gets squared density running under abstraction that couples functiona…
azane Jan 12, 2024
44785d8
gets quad and mc approximations to match, vectorization hacky.
azane Jan 12, 2024
5a11a7a
Merge branch 'staging-robust-icml' into az-influence-finite-difference-2
azane Jan 16, 2024
31cc9ac
adds plotting and comparative to analytic.
azane Jan 16, 2024
f867f2a
adds scratch experiment comparing squared density analytic vs fd appr…
azane Jan 17, 2024
7f10667
fixes dataset splitting, breaks analytic eif
azane Jan 17, 2024
094562a
unfixes an incorrect fix, working now.
azane Jan 17, 2024
0556543
Merge branch 'staging-robust-icml' into az-influence-finite-difference-2
azane Jan 17, 2024
327779a
Merge branch 'staging-robust-icml' into az-influence-finite-difference-2
azane Jan 18, 2024
3e33dc9
refactors finite difference machinery to fit experimental specs.
azane Jan 18, 2024
b21a882
switches to existing rng seed context manager.
azane Jan 19, 2024
79989f9
reverts back to what turns out to be a slightly different seeding con…
azane Jan 19, 2024
Changes from 1 commit: adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas
azane committed Jan 17, 2024
commit f867f2a1759d56cbd23e9c54d9e5bbf1f9e6648f
chirho/robust/handlers/fd_model.py — 12 changes: 9 additions & 3 deletions

```diff
@@ -4,9 +4,13 @@
 from typing import Dict, Optional
 from contextlib import contextmanager
 from chirho.robust.ops import Functional, Point, T
+import numpy as np


 class ModelWithMarginalDensity(torch.nn.Module):
+    def __init__(self, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+
     def density(self, *args, **kwargs):
         # TODO this can probably default to using BatchedNMCLogMarginalLikelihood applied to self,
         # but providing here to avail of analytic densities. Or have a constructor that takes a
@@ -76,11 +80,13 @@ def kernel(self) -> ModelWithMarginalDensity:
         """
         raise NotImplementedError()

-    def __init__(self, default_kernel_point: Dict, default_eps=0., default_lambda=0.1):
-        super().__init__()
+    def __init__(self, default_kernel_point: Dict, *args, default_eps=0., default_lambda=0.1, **kwargs):
+        super().__init__(*args, **kwargs)
         self._eps = default_eps
         self._lambda = default_lambda
         self._kernel_point = default_kernel_point
+        # TODO don't assume .shape[-1]
+        self.ndims = np.sum([v.shape[-1] for v in self._kernel_point.values()])

     @property
     def mixture_weights(self):
@@ -155,7 +161,7 @@ def _influence_fn(*args, **kwargs):
             with model.set_eps(eps), model.set_lambda(lambda_), model.set_kernel_point(kernel_point):
                 psi_p_eps = model.functional(*args, **kwargs)

-            eif_vals.append(-(psi_p_eps - psi_p) / eps)
+            eif_vals.append((psi_p_eps - psi_p) / eps)
         return eif_vals

     return _influence_fn
```
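For intuition on the dropped minus sign in `_influence_fn`: the quotient `(psi_p_eps - psi_p) / eps` is a one-sided finite-difference approximation of the Gateaux derivative of the functional along the kernel perturbation, and it should converge to the analytic influence function with no leading negation. Below is a minimal, self-contained sketch of that claim (illustrative only, not chirho API: the 1D Gaussian model and the values of `x0`, `eps`, and `lam` are all assumptions) for the squared-density functional treated in this PR:

```python
import numpy as np
from scipy.integrate import quad
from scipy.stats import norm

# Base model p and a narrow Gaussian kernel at the evaluation point x0.
p = norm(0.0, 1.0)
x0, eps, lam = 0.5, 1e-3, 1e-3
kernel = norm(x0, np.sqrt(lam))  # variance lam

# Perturbed density p_eps = (1 - eps) * p + eps * kernel, mirroring the
# mixture used by FDModelFunctionalDensity.
def p_eps_pdf(x):
    return (1.0 - eps) * p.pdf(x) + eps * kernel.pdf(x)

# Functional psi(q) = \int q(x)^2 dx, computed by quadrature.
psi_p = quad(lambda x: p.pdf(x) ** 2, -np.inf, np.inf)[0]
psi_p_eps = quad(lambda x: p_eps_pdf(x) ** 2, -np.inf, np.inf)[0]

fd_eif = (psi_p_eps - psi_p) / eps        # the quotient used above, no minus sign
analytic_eif = 2.0 * (p.pdf(x0) - psi_p)  # Gateaux derivative toward delta_{x0}

print(fd_eif, analytic_eif)
```

Running this prints roughly 0.148 vs 0.140; the two agree in the joint limit eps → 0, lam → 0 (with eps/√lam → 0, since the kernel's self-interaction term in the quotient scales as eps/√lam).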
docs/source/robust_fd/squared_normal_density.py — 39 changes: 21 additions & 18 deletions

```diff
@@ -6,16 +6,11 @@
 from scipy.integrate import nquad
 import numpy as np

-# TODO after putting this together, a mixin model would be more appropriate, as we still
-# want explicit coupling between models and functionals but it can be M:M. I.e. mixin the
-# functional that could apply to a number of models, and/or mixin the model that could work
-# with a number of functionals.

-class FDMultivariateNormal(ModelWithMarginalDensity):
+class MultivariateNormalwDensity(ModelWithMarginalDensity):

-    def __init__(self, mean, cov):
-        super().__init__()
+    def __init__(self, mean, cov, *args, **kwargs):
+        super().__init__(*args, **kwargs)

         self.mean = mean
         self.cov = cov
@@ -27,27 +22,31 @@ def forward(self):
         return pyro.sample("x", dist.MultivariateNormal(self.mean, self.cov))


-class _ExpectedNormalDensity(FDModelFunctionalDensity):
+class NormalKernel(FDModelFunctionalDensity):

     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     @property
     def kernel(self):
-        try:
-            mean = self._kernel_point['x']
-        except TypeError as e:
-            raise
-        return FDMultivariateNormal(mean, torch.eye(self.ndims) * self._lambda)
+        # TODO agnostic to names.
+        mean = self._kernel_point['x']
+        return MultivariateNormalwDensity(mean, torch.eye(self.ndims) * self._lambda)


 class PerturbableNormal(FDModelFunctionalDensity):

     def __init__(self, *args, mean, cov, **kwargs):
         super().__init__(*args, **kwargs)

         self.ndims = mean.shape[-1]
-        self.model = FDMultivariateNormal(mean, cov)
+        self.model = MultivariateNormalwDensity(mean, cov)

         self.mean = mean
         self.cov = cov


-class ExpectedNormalDensityQuad(_ExpectedNormalDensity):
+class ExpectedDensityQuadFunctional(FDModelFunctionalDensity):
     """
     Compute the squared normal density using quadrature.
     """
@@ -57,13 +56,16 @@ def __init__(self, *args, **kwargs):

     def functional(self):
         def integrand(*args):
+            # TODO agnostic to kwarg names.
             model_kwargs = kernel_kwargs = dict(x=np.array(args))
             return self.density(model_kwargs, kernel_kwargs) ** 2

-        return nquad(integrand, [[-np.inf, np.inf]] * self.mean.shape[-1])[0]
+        ndim = self._kernel_point['x'].shape[-1]
+
+        return nquad(integrand, [[-np.inf, np.inf]] * ndim)[0]


-class ExpectedNormalDensityMC(_ExpectedNormalDensity):
+class ExpectedDensityMCFunctional(FDModelFunctionalDensity):
     """
     Compute the squared normal density using Monte Carlo.
     """
@@ -72,6 +74,7 @@ def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)

     def functional(self, nmc=1000):
+        # TODO agnostic to kwarg names
         with pyro.plate('samples', nmc):
             points = self()
         return torch.mean(self.density(model_kwargs=dict(x=points), kernel_kwargs=dict(x=points)))
```
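The quadrature and Monte Carlo functionals above target the same quantity: `ExpectedDensityQuadFunctional` integrates the squared density directly, while `ExpectedDensityMCFunctional` exploits the identity that the integral of p² is an expectation of the density under the model itself, which is what `torch.mean(self.density(...))` over the `nmc` plate samples estimates:

```math
\psi(p) \;=\; \int p(x)^2 \, dx \;=\; \mathbb{E}_{X \sim p}\big[\,p(X)\,\big] \;\approx\; \frac{1}{N}\sum_{i=1}^{N} p(x_i), \qquad x_i \overset{\text{iid}}{\sim} p .
```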
docs/source/robust_fd_scratch.py — 4 changes: 3 additions & 1 deletion

```diff
@@ -1,3 +1,5 @@
+raise NotImplementedError()
+
 from robust_fd.squared_normal_density import ExpectedNormalDensityQuad, ExpectedNormalDensityMC, _ExpectedNormalDensity
 from chirho.robust.handlers.fd_model import fd_influence_fn
 import numpy as np
@@ -91,7 +93,7 @@ def compute_analytic_eif(model: _ExpectedNormalDensity, points):
     funcval = model.functional()
     density = model.density(points, points)

-    return 2. * (funcval - density)
+    return 2. * (density - funcval)


 analytic_eif = compute_analytic_eif(end_quad, points).numpy()
```
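The sign fix in `compute_analytic_eif` agrees with the standard Gateaux-derivative computation for this functional: perturbing p toward a point mass at x,

```math
\left.\frac{d}{d\varepsilon}\, \psi\big((1-\varepsilon)p + \varepsilon\delta_x\big)\right|_{\varepsilon=0}
\;=\; 2\int p(x')\,\big(\delta_x - p\big)(x')\,dx'
\;=\; 2\big(p(x) - \psi(p)\big),
```

i.e. `2. * (density - funcval)`, which is also consistent with removing the leading minus from the finite-difference quotient in `fd_model.py`.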