Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added more tests for linearize and make_empirical_fisher_vp #405

Merged
merged 16 commits into from
Dec 6, 2023

Conversation

agrawalraj
Copy link
Contributor

@agrawalraj agrawalraj commented Nov 29, 2023

Closes #386

  • robust_fixtures.py contains utilities to remove repeated code across testing files (just as we did for the tests in dynamical)
  • test_linearize_against_analytic_ate compares our empirical efficient influence function for average treatment effect against the analytical formula from Kennedy
  • test_empirical_fisher_vp_against_analytical compares the empirical fisher vector product against the analytical fisher vector product for the special multivariate Gaussian case
  • test_fisher_vmap_smoke check composition with torch.func.vmap
  • test_fisher_grad_smoke) check composition with torch.func.grad

@agrawalraj agrawalraj added the status:awaiting review Awaiting response from reviewer label Nov 30, 2023
@SamWitty SamWitty added status:WIP Work-in-progress not yet ready for review and removed status:awaiting review Awaiting response from reviewer labels Nov 30, 2023
@SamWitty
Copy link
Collaborator

@agrawalraj , I removed the "status:awaiting review" tag because it looked like you were still working on this. Feel free to change it back if you want.

@agrawalraj
Copy link
Contributor Author

agrawalraj commented Dec 1, 2023

@agrawalraj , I removed the "status:awaiting review" tag because it looked like you were still working on this. Feel free to change it back if you want.

Thanks yeah I realized the tests weren't passing so made some small edits to get them to pass now. Adding a few more tests so going to mark to WIP.

@agrawalraj agrawalraj added status:awaiting review Awaiting response from reviewer status:WIP Work-in-progress not yet ready for review and removed status:WIP Work-in-progress not yet ready for review status:awaiting review Awaiting response from reviewer labels Dec 1, 2023
@eb8680
Copy link
Contributor

eb8680 commented Dec 6, 2023

@agrawalraj I see #408 is marked as ready for review, does that mean this is ready as well?

@agrawalraj
Copy link
Contributor Author

agrawalraj commented Dec 6, 2023

@agrawalraj I see #408 is marked as ready for review, does that mean this is ready as well?

The failing test here will be addressed by #408. Let me fix the linting error now.

Update: @eb8680 this is ready for review

@eb8680 eb8680 added status:awaiting review Awaiting response from reviewer and removed status:WIP Work-in-progress not yet ready for review labels Dec 6, 2023
pyro.settings.set(module_local_params=True)


def test_empirical_fisher_vp_nmclikelihood_cg_composition():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

If this is the test you expect to fail without #408, can you mark it with pytest.mark.xfail(reason="fails without fix in https://github.com/BasisResearch/chirho/pull/408")?

Copy link
Contributor Author

@agrawalraj agrawalraj Dec 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Given that the change in #408 now just uses SeedMessenger in the linearize body, I can fix these errors by doing the same in this test. I'll push up these changes now.

Copy link
Contributor Author

@agrawalraj agrawalraj Dec 6, 2023

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@eb8680 just made this change!

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine, but remember to remove the SeedMessenger here in #408, since you want this test to exercise the changes made to NMCLogPredictiveLikelihood there.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Ok in #408, I followed your suggestion here #408 (review), and did not actually modify NMCLogPredictiveLikelihood. I just checked if

log_prob = NMCLogPredictiveLikelihood(
        model, guide, num_samples=3, max_plate_nesting=3
    )
log_prob_params, func_log_prob = make_functional_call(log_prob)
func_log_prob = SeedMessenger(123)(func_log_prob)

works in the test_nmc_likelihood_seeded test in #408. I think we can move this discussion to #408 if you have other suggestions on this point!

Copy link
Contributor

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Looks good, just a few comments.

tests/robust/test_internals_linearize.py Outdated Show resolved Hide resolved
tests/robust/test_internals_linearize.py Outdated Show resolved Hide resolved
tests/robust/test_internals_linearize.py Outdated Show resolved Hide resolved
tests/robust/test_internals_linearize.py Outdated Show resolved Hide resolved
tests/robust/robust_fixtures.py Outdated Show resolved Hide resolved
tests/robust/robust_fixtures.py Outdated Show resolved Hide resolved
pyro.settings.set(module_local_params=True)


def test_empirical_fisher_vp_nmclikelihood_cg_composition():
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

That's fine, but remember to remove the SeedMessenger here in #408, since you want this test to exercise the changes made to NMCLogPredictiveLikelihood there.

tests/robust/robust_fixtures.py Show resolved Hide resolved
@eb8680 eb8680 added status:awaiting response Awaiting response from creator and removed status:awaiting review Awaiting response from reviewer labels Dec 6, 2023
@agrawalraj agrawalraj added status:awaiting review Awaiting response from reviewer and removed status:awaiting response Awaiting response from creator labels Dec 6, 2023
Copy link
Contributor

@eb8680 eb8680 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

LGTM

@eb8680 eb8680 merged commit c5fe64b into staging-robust Dec 6, 2023
6 checks passed
@eb8680 eb8680 deleted the ra-sw-fisher-tests branch December 6, 2023 20:55
eb8680 added a commit that referenced this pull request Jan 12, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
agrawalraj added a commit that referenced this pull request Jan 19, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* add abstractions and simple temp scratch to test with squared unit normal functional with perturbation.

* removes old scratch notebook

* gets squared density running under abstraction that couples functionals and models

* gets quad and mc approximations to match, vectorization hacky.

* adds plotting and comparative to analytic.

* adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas

* fixes dataset splitting, breaks analytic eif

* unfixes an incorrect fix, working now.

* refactors finite difference machinery to fit experimental specs.

* switches to existing rng seed context manager.

* reverts back to what turns out to be a slightly different seeding context.

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
agrawalraj added a commit that referenced this pull request Jan 24, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* add abstractions and simple temp scratch to test with squared unit normal functional with perturbation.

* removes old scratch notebook

* gets squared density running under abstraction that couples functionals and models

* gets quad and mc approximations to match, vectorization hacky.

* adds plotting and comparative to analytic.

* adds scratch experiment comparing squared density analytic vs fd approx across various epsilon lambdas

* fixes dataset splitting, breaks analytic eif

* unfixes an incorrect fix, working now.

* refactors finite difference machinery to fit experimental specs.

* switches to existing rng seed context manager.

* reverts back to what turns out to be a slightly different seeding context.

* gets fd integrated into experiment exec and running.

* adds perturbable normal model to statics listing

* switches back to mean not mu

* lines up mean mu loc naming correctly.

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
agrawalraj added a commit that referenced this pull request Jan 25, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* progress on tmle

* placeholder test

* more progress on TMLE

* more progress, still need to refactor

* progress on variational tmle

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* more progress on tmle

* really resolved merge conflicts

* more progress, still a bit stuck on functional tensors

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* update tmle signature and remove unused imports

* make tmle signature consistent with one-step

* lint

* progress

* pair program still issues

* debugging still

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* more attempts

* added scipy optimize :(

* more progress

* more progress

* working end-to-end tmle

* remove comment

* revert changes

* update tests and defaults

* lint

* playing with tmle performance

* more tweaks

* pulled out influence computation and changed loss

* finally got tmle working

* revert test

* added placeholder for passing in influence_fn_estimator

* analytic influence for example

* lint

* fix estimator

* fix tests

* lint

* notebook

* bump notebook

* use torchopt

* add torchopt

* rerun tmle notebook with effect = 1

* lint

---------

Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: Eli <[email protected]>
Co-authored-by: Raj Agrawal <[email protected]>
Co-authored-by: eb8680 <[email protected]>
SamWitty added a commit that referenced this pull request Jul 18, 2024
* added robust folder

* uncommited scratch work for log prob

* untested variational log prob

* uncomitted changes

* uncomitted changes

* pair coding w/ eli

* added tests w/ Eli

* eif

* linting

* moving test autograd to internals and deleted old utils file

* sketch influence implementation

* fix more args

* ops file

* file

* format

* lint

* clean up influence and tests

* make tests more generic

* guess max plate nesting

* linearize

* rename file

* tensor flatten

* predictive eif

* jvp type

* reorganize files

* shrink test case

* move guess_max_plate_nesting

* move cg solver to linearze

* type alias

* test_ops

* basic cg tests

* remove failing test case

* format

* move paramdict up

* remove obsolete test files

* add empty handlers

* add chirho.robust to docs

* fix memory leak in tests

* make typing compatible with python 3.8

* typing_extensions

* add branch to ci

* predictive

* remove imprecise annotation

* Added more tests for `linearize` and `make_empirical_fisher_vp` (#405)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* removed missing import

* fixed failing test with seeding

* addressing Eli's comments

* Add upper bound on number of CG steps (#404)

* upper bound on cg_iters

* address comment

* fixed test for non-symmetric matrix (#437)

* Make `NMCLogPredictiveLikelihood` seeded (#408)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* switched back to different

* Use Hessian formulation of Fisher information in `make_empirical_fisher_vp` (#430)

* hessian vector product formulation for fisher

* ignoring small type error

* fixed linting error

* Add new `SimpleModel` and `SimpleGuide` (#440)

* initial test against analytic fisher vp (pair coded w/ sam)

* linting

* added check against analytic ate

* added vmap and grad smoke tests

* added missing init

* linting and consolidated fisher tests to one file

* fixed types

* fixing linting errors

* trying to fix type error for python 3.8

* fixing test errors

* added patch to test to prevent from failing when denom is small

* composition issue

* seeded NMC implementation

* linting

* removed missing import

* changed to eli's seedmessenger suggestion

* added failing edge case

* explicitly add max plate argument

* added warning message

* fixed linting error and test failure case from too many cg iters

* eli's contextlib seeding strategy

* removed seedmessenger from test

* randomness should be shared across calls

* uncomitted change before branch switch

* switched back to different

* added revised simple model and guide

* added multiple link functions in test

* linting

* Batching in `linearize` and `influence` (#465)

* batching in linearize and influence

* addressing eli's review

* added optimization for pointwise false case

* fixing lint error

* batched cg (#466)

* One step correction implemented (#467)

* one step correction

* increased tolerance

* fixing lint issue

* Replace some `torch.vmap` usage with a hand-vectorized `BatchedNMCLogPredictiveLikelihood` (#473)

* sketch batched nmc lpd

* nits

* fix type

* format

* comment

* comment

* comment

* typo

* typo

* add condition to help guarantee idempotence

* simplify edge case

* simplify plate_name

* simplify batchedobservation logic

* factorize

* simplify batched

* reorder

* comment

* remove plate_names

* types

* formatting and type

* move unbind to utils

* remove max_plate_nesting arg from get_traces

* comment

* nit

* move get_importance_traces to utils

* fix types

* generic obs type

* lint

* format

* handle observe in batchedobservations

* event dim

* move batching handlers to utils

* replace 2/3 vmaps, tests pass

* remove dead code

* format

* name args

* lint

* shuffle code

* try an extra optimization in batchedlatents

* add another optimization

* undo changes to test

* remove inplace adds

* add performance test showing speedup

* document internal helpers

* batch latents test

* move batch handlers to predictive

* add bind_leftmost_dim, document PredictiveFunctional and PredictiveModel

* use bind_leftmost_dim in log prob

* Added documentation for `chirho.robust` (#470)

* documentation

* documentation clean up w/ eli

* fix lint issue

* old dr notebook that got deleted from wrong merge

* added missing fig

* redid notebook with new interface

* Make functional argument to influence_fn required (#487)

* Make functional argument required

* estimator

* docstring

* Remove guide argument from `influence_fn` and `linearize` (#489)

* Make functional argument required

* estimator

* docstring

* Remove guide, make tests pass

* rename internals.predictive to internals.nmc

* expose handlers.predictive

* expose handlers.predictive

* docstrings

* fix doc build

* fix equation

* docstring import

---------

Co-authored-by: Sam Witty <[email protected]>

* updated labels

* updated w/ new interface but only 1 data sim

* Make influence_fn a higher-order Functional (#492)

* make influence a functional

* fix test

* multiple arguments

* doc

* docstring

* docstring

* uncommitted changes

* Add full corrected one step estimator (#476)

* added scaffolding to one step estimator

* kept signature the same as one_step_correction

* lint

* refactored test to include multiple estimators

* typo

* revise error

* added dict handling

* remove assert

* more informative error message

* replace dispatch with pytree flatten and unflatten

* revert arg for influence_function_estimator

* docs and lint

* lingering influence_fn

* fixed missing return

* rename

* lint

* add *model to appease the linter

* kernel speedup

* before switching to krr formulation

* uncommitted changes

* updated w/ new interface; removed GP section for now

* runs but not matching

* still not working, going to make major changes

* remove debug script

* remove file

* remove file

* add

* update interfaces

* finished running

* outline

* remove outline for now

* simplify notebook

* merge

---------

Co-authored-by: Eli <[email protected]>
Co-authored-by: Sam Witty <[email protected]>
Co-authored-by: eb8680 <[email protected]>
Co-authored-by: Eli <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants