-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Batching in linearize
and influence
#465
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good, just a couple of questions and comments
pinned_fvp = reset_rng_state(pyro.util.get_rng_state())(fvp) | ||
point_score: ParamDict = score_fn(log_prob_params, point, *args, **kwargs) | ||
return cg_solver(pinned_fvp, point_score) | ||
batched_func_log_prob = torch.vmap( |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
score_fn
is already defined above. Can we simplify this to vmap
over score_fn
, or was there some performance advantage in applying vmap
inside of grad
/jacrev
/jacfwd
?
batch_score_fn = torch.func.vmap(
lambda p, data: score_fn(p, data, *args, **kwargs),
in_dims=(None, 0),
randomness="different",
)
point_scores = batch_score_fn(log_prob_params, points)
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Oops I forgot to remove the line score_fn = torch.func.grad(func_log_prob)
, thanks for the catch! I don't think we want to vmap
over score_fn
because once #463 is addressed, func_log_prob
will return the log probability of multiple datapoints at once. So we would need to use jacrev
or jacfwd
since the output of func_log_prob
is a vector. @eb8680
chirho/robust/internals/linearize.py
Outdated
if log_prob_params_numel > points[next(iter(points))].shape[0]: | ||
score_fn = torch.func.jacrev(batched_func_log_prob) | ||
else: | ||
score_fn = torch.func.jacfwd(batched_func_log_prob, randomness="different") |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Did you see a performance improvement with jacfwd
? I would naively expect it to be worse than jacrev
here since even though the output of score_fn
of batched_func_log_prob
may be higher-dimensional, its components are completely independent so there's no opportunity to share common subexpressions.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I just did this assuming that jacfwd
would be fast if output space was larger than the input space but I think you make a good point. If the number of points in points
is extremely large, then pointwise_influence
should really be set to False
. In the latest commit, I made the pointwise_influence
much more efficient using torch.func.vjp
to avoid creating a huge Jacobian matrix.
chirho/robust/ops.py
Outdated
)[1] | ||
def _fn( | ||
points: Point[T], | ||
pointwise_influence: bool = False, |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Can we move the new pointwise_influence
keyword argument up to be a keyword argument of linearize
and influence_fn
instead of changing the signature of the output functions?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Yup I can do that!
@eb8680 just addressed all points, thanks! |
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * add abstractions and simple temp scratch to test with squared unit normal functional with perturbation. * removes old scratch notebook * gets squared density running under abstraction that couples functionals and models * gets quad and mc approximations to match, vectorization hacky. * adds plotting and comparative to analytic. * adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas * fixes dataset splitting, breaks analytic eif * unfixes an incorrect fix, working now. * refactors finite difference machinery to fit experimental specs. * switches to existing rng seed context manager. * reverts back to what turns out to be a slightly different seeding context. --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Sam Witty <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * add abstractions and simple temp scratch to test with squared unit normal functional with perturbation. * removes old scratch notebook * gets squared density running under abstraction that couples functionals and models * gets quad and mc approximations to match, vectorization hacky. * adds plotting and comparative to analytic. * adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas * fixes dataset splitting, breaks analytic eif * unfixes an incorrect fix, working now. * refactors finite difference machinery to fit experimental specs. * switches to existing rng seed context manager. * reverts back to what turns out to be a slightly different seeding context. * gets fd integrated into experiment exec and running. * adds perturbable normal model to statics listing * switches back to mean not mu * lines up mean mu loc naming correctly. --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Sam Witty <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * progress on tmle * placeholder test * more progress on TMLE * more progress, still need to refactor * progress on variational tmle * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * more progress on tmle * really resolved merge conflicts * more progress, still a bit stuck on functional tensors * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * update tmle signature and remove unused imports * make tmle signature consistent with one-step * lint * progress * pair program still issues * debugging still * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * more attempts * added scipy optimize :( * more progress * more progress * working end-to-end tmle * remove comment * revert changes * update tests and defaults * lint * playing with tmle performance * more tweaks * pulled out influence computation and changed loss * finally got tmle working * revert test * added placeholder for passing in influence_fn_estimator * analytic influence for example * lint * fix estimator * fix tests * lint * notebook * bump notebook * use torchopt * add torchopt * rerun tmle notebook with effect = 1 * lint --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * old dr notebook that got deleted from wrong merge * added missing fig * redid notebook with new interface * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * updated labels * updated w/ new interface but only 1 data sim * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * uncommitted changes * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * kernel speedup * before switching to krr formulation * uncommitted changes * updated w/ new interface; removed GP section for now * runs but not matching * still not working, going to make major changes * remove debug script * remove file * remove file * add * update interfaces * finished running * outline * remove outline for now * simplify notebook * merge --------- Co-authored-by: Eli <[email protected]> Co-authored-by: Sam Witty <[email protected]> Co-authored-by: eb8680 <[email protected]> Co-authored-by: Eli <[email protected]>
Closes #464