Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Implement TMLE estimator using Influence Functions (#484)
* added robust folder * uncommited scratch work for log prob * untested variational log prob * uncomitted changes * uncomitted changes * pair coding w/ eli * added tests w/ Eli * eif * linting * moving test autograd to internals and deleted old utils file * sketch influence implementation * fix more args * ops file * file * format * lint * clean up influence and tests * make tests more generic * guess max plate nesting * linearize * rename file * tensor flatten * predictive eif * jvp type * reorganize files * shrink test case * move guess_max_plate_nesting * move cg solver to linearze * type alias * test_ops * basic cg tests * remove failing test case * format * move paramdict up * remove obsolete test files * add empty handlers * add chirho.robust to docs * fix memory leak in tests * make typing compatible with python 3.8 * typing_extensions * add branch to ci * predictive * remove imprecise annotation * Added more tests for `linearize` and `make_empirical_fisher_vp` (#405) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * removed missing import * fixed failing test with seeding * addressing Eli's comments * Add upper bound on number of CG steps (#404) * upper bound on cg_iters * address comment * fixed test for non-symmetric matrix (#437) * Make `NMCLogPredictiveLikelihood` seeded (#408) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * switched back to different * Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430) * hessian vector product formulation for fisher * ignoring small type error * fixed linting error * Add new `SimpleModel` and `SimpleGuide` (#440) * initial test against analytic fisher vp (pair coded w/ sam) * linting * added check against analytic ate * added vmap and grad smoke tests * added missing init * linting and consolidated fisher tests to one file * fixed types * fixing linting errors * trying to fix type error for python 3.8 * fixing test errors * added patch to test to prevent from failing when denom is small * composition issue * seeded NMC implementation * linting * removed missing import * changed to eli's seedmessenger suggestion * added failing edge case * explicitly add max plate argument * added warning message * fixed linting error and test failure case from too many cg iters * eli's contextlib seeding strategy * removed seedmessenger from test * randomness should be shared across calls * uncomitted change before branch switch * switched back to different * added revised simple model and guide * added multiple link functions in test * linting * Batching in `linearize` and `influence` (#465) * batching in linearize and influence * addressing eli's review * added optimization for pointwise false case * fixing lint error * batched cg (#466) * One step correction implemented (#467) * one step correction * increased tolerance * fixing lint issue * Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473) * sketch batched nmc lpd * nits * fix type * format * comment * comment * comment * typo * typo * add condition to help guarantee idempotence * simplify edge case * simplify plate_name * simplify batchedobservation logic * factorize * simplify batched * reorder * comment * remove plate_names * types * formatting and type * move unbind to utils * remove max_plate_nesting arg from get_traces * comment * nit * move get_importance_traces to utils * fix types * generic obs type * lint * format * handle observe in batchedobservations * event dim * move batching handlers to utils * replace 2/3 vmaps, tests pass * remove dead code * format * name args * lint * shuffle code * try an extra optimization in batchedlatents * add another optimization * undo changes to test * remove inplace adds * add performance test showing speedup * document internal helpers * batch latents test * move batch handlers to predictive * add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel * use bind_leftmost_dim in log prob * Added documentation for `chirho.robust` (#470) * documentation * documentation clean up w/ eli * fix lint issue * progress on tmle * placeholder test * more progress on TMLE * more progress, still need to refactor * progress on variational tmle * Make functional argument to influence_fn required (#487) * Make functional argument required * estimator * docstring * Remove guide argument from `influence_fn` and `linearize` (#489) * Make functional argument required * estimator * docstring * Remove guide, make tests pass * rename internals.predictive to internals.nmc * expose handlers.predictive * expose handlers.predictive * docstrings * fix doc build * fix equation * docstring import --------- Co-authored-by: Sam Witty <[email protected]> * more progress on tmle * really resolved merge conflicts * more progress, still a bit stuck on functional tensors * Make influence_fn a higher-order Functional (#492) * make influence a functional * fix test * multiple arguments * doc * docstring * docstring * update tmle signature and remove unused imports * make tmle signature consistent with one-step * lint * progress * pair program still issues * debugging still * Add full corrected one step estimator (#476) * added scaffolding to one step estimator * kept signature the same as one_step_correction * lint * refactored test to include multiple estimators * typo * revise error * added dict handling * remove assert * more informative error message * replace dispatch with pytree flatten and unflatten * revert arg for influence_function_estimator * docs and lint * lingering influence_fn * fixed missing return * rename * lint * add *model to appease the linter * more attempts * added scipy optimize :( * more progress * more progress * working end-to-end tmle * remove comment * revert changes * update tests and defaults * lint * playing with tmle performance * more tweaks * pulled out influence computation and changed loss * finally got tmle working * revert test * added placeholder for passing in influence_fn_estimator * analytic influence for example * lint * fix estimator * fix tests * lint * notebook * bump notebook * use torchopt * add torchopt * rerun tmle notebook with effect = 1 * lint --------- Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: Eli <[email protected]> Co-authored-by: Raj Agrawal <[email protected]> Co-authored-by: eb8680 <[email protected]>
- Loading branch information