Skip to content

Commit

Permalink
Implement TMLE estimator using Influence Functions (#484)
Browse files Browse the repository at this point in the history
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* progress on tmle

* placeholder test

* more progress on TMLE

* more progress, still need to refactor

* progress on variational tmle

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* more progress on tmle

* really resolved merge conflicts

* more progress, still a bit stuck on functional tensors

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* update tmle signature and remove unused imports

* make tmle signature consistent with one-step

* lint

* progress

* pair program still issues

* debugging still

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* more attempts

* added scipy optimize :(

* more progress

* more progress

* working end-to-end tmle

* remove comment

* revert changes

* update tests and defaults

* lint

* playing with tmle performance

* more tweaks

* pulled out influence computation and changed loss

* finally got tmle working

* revert test

* added placeholder for passing in influence_fn_estimator

* analytic influence for example

* lint

* fix estimator

* fix tests

* lint

* notebook

* bump notebook

* use torchopt

* add torchopt

* rerun tmle notebook with effect = 1

* lint

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
  • Loading branch information
5 people authored Jan 25, 2024
1 parent bd63526 commit da3d5d5
Show file tree
Hide file tree
Showing 4 changed files with 1,170 additions and 5 deletions.
214 changes: 213 additions & 1 deletion chirho/robust/handlers/estimators.py
Original file line number Diff line number Diff line change
@@ -1,18 +1,230 @@
import copy
import warnings
from typing import Any, Callable, TypeVar

import torch
import torchopt
from typing_extensions import ParamSpec

from chirho.robust.handlers.predictive import PredictiveFunctional
from chirho.robust.internals.utils import make_functional_call
from chirho.robust.ops import Functional, Point, influence_fn

P = ParamSpec("P")
S = TypeVar("S")
T = TypeVar("T")


def tmle_scipy_optimize_wrapper(
packed_influence, log_jitter: float = 1e-6
) -> torch.Tensor:
import numpy as np
import scipy
from scipy.optimize import LinearConstraint

# Turn things into numpy. This makes us sad... :(
D = packed_influence.detach().numpy()

N, L = D.shape[0], D.shape[1]

def loss(epsilon):
correction = 1 + D.dot(epsilon)

return -np.sum(np.log(np.maximum(correction, log_jitter)))

positive_density_constraint = LinearConstraint(
D, -1 * np.ones(N), np.inf * np.ones(N)
)

epsilon_solve = scipy.optimize.minimize(
loss, np.zeros(L, dtype=D.dtype), constraints=positive_density_constraint
)

if not epsilon_solve.success:
warnings.warn("TMLE optimization did not converge.", RuntimeWarning)

# Convert epsilon back to torch. This makes us happy... :)
packed_epsilon = torch.tensor(epsilon_solve.x, dtype=packed_influence.dtype)

return packed_epsilon


# TODO: revert influence_estimator to influence_fn and use handlers for influence_fn
def tmle(
functional: Functional[P, S],
test_point: Point,
learning_rate: float = 1e-5,
n_grad_steps: int = 100,
n_tmle_steps: int = 1,
num_nmc_samples: int = 1000,
num_grad_samples: int = 1000,
log_jitter: float = 1e-6,
verbose: bool = False,
influence_estimator: Callable[
[Functional[P, S], Point[T]], Functional[P, S]
] = influence_fn,
**influence_kwargs,
) -> Functional[P, S]:
from chirho.robust.internals.nmc import BatchedNMCLogMarginalLikelihood

def _solve_epsilon(prev_model: torch.nn.Module, *args, **kwargs) -> torch.Tensor:
# find epsilon that minimizes the corrected density on test data

influence_at_test = influence_estimator(
functional, test_point, **influence_kwargs
)(prev_model)(*args, **kwargs)

flat_influence_at_test, _ = torch.utils._pytree.tree_flatten(influence_at_test)

N = flat_influence_at_test[0].shape[0]

packed_influence_at_test = torch.concatenate(
[i.reshape(N, -1) for i in flat_influence_at_test]
)

packed_epsilon = tmle_scipy_optimize_wrapper(packed_influence_at_test)

return packed_epsilon

def _solve_model_projection(
packed_epsilon: torch.Tensor,
prev_model: torch.nn.Module,
*args,
**kwargs,
) -> torch.nn.Module:
prev_params, functional_model = make_functional_call(
PredictiveFunctional(prev_model, num_samples=num_grad_samples)
)
prev_params = {k: v.detach() for k, v in prev_params.items()}

# Sample data from the model. Note that we only sample once during projection.
data = {
k: v
for k, v in functional_model(prev_params, *args, **kwargs).items()
if k in test_point
}

batched_log_prob: torch.nn.Module = BatchedNMCLogMarginalLikelihood(
prev_model, num_samples=num_nmc_samples
)

_, log_p_phi = make_functional_call(batched_log_prob)

influence_at_data = influence_estimator(functional, data, **influence_kwargs)(
prev_model
)(*args, **kwargs)
flat_influence_at_data, _ = torch.utils._pytree.tree_flatten(influence_at_data)
N_x = flat_influence_at_data[0].shape[0]

packed_influence_at_data = torch.concatenate(
[i.reshape(N_x, -1) for i in flat_influence_at_data]
).detach()

log_likelihood_correction = torch.log(
torch.maximum(
1 + packed_influence_at_data.mv(packed_epsilon),
torch.tensor(log_jitter),
)
).detach()
if verbose:
influence_at_test = influence_estimator(
functional, test_point, **influence_kwargs
)(prev_model)(*args, **kwargs)
flat_influence_at_test, _ = torch.utils._pytree.tree_flatten(
influence_at_test
)
N = flat_influence_at_test[0].shape[0]

packed_influence_at_test = torch.concatenate(
[i.reshape(N, -1) for i in flat_influence_at_test]
).detach()

log_likelihood_correction_at_test = torch.log(
torch.maximum(
1 + packed_influence_at_test.mv(packed_epsilon),
torch.tensor(log_jitter),
)
)

print("previous log prob at test", log_p_phi(prev_params, test_point).sum())
print(
"new log prob at test",
(
log_p_phi(prev_params, test_point)
+ log_likelihood_correction_at_test
).sum(),
)

log_p_epsilon_at_data = (
log_likelihood_correction + log_p_phi(prev_params, data)
).detach()

def loss(new_params):
log_p_phi_at_data = log_p_phi(new_params, data)
return torch.sum((log_p_phi_at_data - log_p_epsilon_at_data) ** 2)

grad_fn = torch.func.grad(loss)

new_params = {
k: v.clone().detach().requires_grad_(True) for k, v in prev_params.items()
}

optimizer = torchopt.adam(lr=learning_rate)

optimizer_state = optimizer.init(new_params)

for i in range(n_grad_steps):
grad = grad_fn(new_params)
if verbose and i % 100 == 0:
print(f"inner_iteration_{i}_loss", loss(new_params))
for parameter_name, parameter in prev_model.named_parameters():
parameter.data = new_params[f"model.{parameter_name}"]

estimate = functional(prev_model)(*args, **kwargs)
assert isinstance(estimate, torch.Tensor)
print(
f"inner_iteration_{i}_estimate",
estimate.detach().item(),
)
updates, optimizer_state = optimizer.update(
grad, optimizer_state, inplace=False
)
new_params = torchopt.apply_updates(new_params, updates)

for parameter_name, parameter in prev_model.named_parameters():
parameter.data = new_params[f"model.{parameter_name}"]

return prev_model

def _corrected_functional(*models: Callable[P, Any]) -> Callable[P, S]:
assert len(models) == 1
model = models[0]

assert isinstance(model, torch.nn.Module)

def _estimator(*args, **kwargs) -> S:
tmle_model = copy.deepcopy(model)

for _ in range(n_tmle_steps):
packed_epsilon = _solve_epsilon(tmle_model, *args, **kwargs)

tmle_model = _solve_model_projection(
packed_epsilon, tmle_model, *args, **kwargs
)
return functional(tmle_model)(*args, **kwargs)

return _estimator

return _corrected_functional


# TODO: revert influence_estimator to influence_fn and use handlers for influence_fn
def one_step_corrected_estimator(
functional: Functional[P, S],
*test_points: Point[T],
influence_estimator: Callable[
[Functional[P, S], Point[T]], Functional[P, S]
] = influence_fn,
**influence_kwargs,
) -> Functional[P, S]:
"""
Expand All @@ -30,7 +242,7 @@ def one_step_corrected_estimator(
"""
influence_kwargs_one_step = influence_kwargs.copy()
influence_kwargs_one_step["pointwise_influence"] = False
eif_fn = influence_fn(functional, *test_points, **influence_kwargs_one_step)
eif_fn = influence_estimator(functional, *test_points, **influence_kwargs_one_step)

def _corrected_functional(*model: Callable[P, Any]) -> Callable[P, S]:
plug_in_estimator = functional(*model)
Expand Down
Loading

0 comments on commit da3d5d5

Please sign in to comment.