-
Notifications
You must be signed in to change notification settings - Fork 12
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Add upper bound on number of CG steps #404
Conversation
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think we don't want to do this here because then we can't vmap over _fn
in linearize
. I think we should move this check to linearize
@agrawalraj this is ready for review |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Looks good!
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * add abstractions and simple temp scratch to test with squared unit normal functional with perturbation. * removes old scratch notebook * gets squared density running under abstraction that couples functionals and models * gets quad and mc approximations to match, vectorization hacky. * adds plotting and comparative to analytic. * adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas * fixes dataset splitting, breaks analytic eif * unfixes an incorrect fix, working now. * refactors finite difference machinery to fit experimental specs. * switches to existing rng seed context manager. * reverts back to what turns out to be a slightly different seeding context. --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Sam Witty <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * add abstractions and simple temp scratch to test with squared unit normal functional with perturbation. * removes old scratch notebook * gets squared density running under abstraction that couples functionals and models * gets quad and mc approximations to match, vectorization hacky. * adds plotting and comparative to analytic. * adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas * fixes dataset splitting, breaks analytic eif * unfixes an incorrect fix, working now. * refactors finite difference machinery to fit experimental specs. * switches to existing rng seed context manager. * reverts back to what turns out to be a slightly different seeding context. * gets fd integrated into experiment exec and running. * adds perturbable normal model to statics listing * switches back to mean not mu * lines up mean mu loc naming correctly. --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Sam Witty <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * progress on tmle * placeholder test * more progress on TMLE * more progress, still need to refactor * progress on variational tmle * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * more progress on tmle * really resolved merge conflicts * more progress, still a bit stuck on functional tensors * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * update tmle signature and remove unused imports * make tmle signature consistent with one-step * lint * progress * pair program still issues * debugging still * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * more attempts * added scipy optimize :( * more progress * more progress * working end-to-end tmle * remove comment * revert changes * update tests and defaults * lint * playing with tmle performance * more tweaks * pulled out influence computation and changed loss * finally got tmle working * revert test * added placeholder for passing in influence_fn_estimator * analytic influence for example * lint * fix estimator * fix tests * lint * notebook * bump notebook * use torchopt * add torchopt * rerun tmle notebook with effect = 1 * lint --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * old dr notebook that got deleted from wrong merge * added missing fig * redid notebook with new interface * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * updated labels * updated w/ new interface but only 1 data sim * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * uncommitted changes * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * kernel speedup * before switching to krr formulation * uncommitted changes * updated w/ new interface; removed GP section for now * runs but not matching * still not working, going to make major changes * remove debug script * remove file * remove file * add * update interfaces * finished running * outline * remove outline for now * simplify notebook * merge --------- Co-authored-by: Eli <[email protected]> Co-authored-by: Sam Witty <[email protected]> Co-authored-by: eb8680 <[email protected]> Co-authored-by: Eli <[email protected]>
Addresses @agrawalraj's comments in #398: