Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Add find_MAP with close JAX integration and fix bug with Laplace fit #385

Merged
merged 21 commits into from
Dec 4, 2024

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Oct 27, 2024

Closes #376

This PR adds code to run find_MAP using JAX. I'm using JAX for gradients, because I found the compile times were faster. Open to suggestions/rebuke.

It also adds a fit_laplace function, which is bad because we already have a fit_laplace function. This one has slightly different objective though -- it isn't meant to be used as a step sampler on a subset of model variables. Instead, it is meant to be used on the MAP result to give an approximation to the full posterior. My function also lets you do the Laplace approximation in the transformed space, then do sample-wise reverse transformation. I think this is legit, and lets you obtain approximate posteriors that respect the domain of the prior. Tagging @theorashid so we can resolve the differences.

Last point is that I added a dependency on better_optimize. This is a package I wrote that basically rips out the wrapper code used in PyMC find_MAP and applies it to arbitrary optimization problems. It is more feature complete than the PyMC wrapper -- it supports all optimizer modes for scipy.optimize.minimize and scipy.optimize.root, and also helps get keywords to the right place in those functions (who can ever remember if an argument goes in method_kwargs or in the funciton itself?). I plan to add support for basinhopping as well, which will be nice for really hairy minimizations.

I could see an objection to adding another dependency, but 1) it's a lightweight wrapper around functionality that doesn't really belong in PyMC anyway, and 2) it's a big value-add compared to working directly with the scipy.optimize functions, which have gnarly, inconsistent signatures.

@theorashid
Copy link
Contributor

Hey, nice one, yeah I agree, we should only have one fit_laplace function.

it isn't meant to be used as a step sampler on a subset of model variables

The current fit_laplace isn't either. It isn't a step sampler. (The INLA stuff #340 still has a few blockers so that's separate and not yet in the library.) The implementation was made by a user following Statistical Rethinking, where McElreath fits some models using the Laplace approximation of all parameters.

Current behaviour of fit_laplace is:

The behaviour when you only pass a subset of variables isn't really desirable in my opinion (see #345 (comment)), so we put a warning. So as you say:

Instead, it is meant to be used on the MAP result to give an approximation to the full posterior.

Agree, that's the best plan for fit_laplace.

Judging by your docs and a quick glance at your code, I think you're basically doing the same thing. The current implementation is few lines of code and a few docs, so I reckon

  1. make sure you can pass the test case with your method, which is an example from BDA3 https://github.com/pymc-devs/pymc-experimental/blob/main/tests/test_laplace.py
  2. throw any of the useful code and docs into your method

Then it should be safe to delete the existing code and we can go back to one fit_laplace.

I could see an objection to adding another dependency

I would love a generic optimiser in p u r e pytensor, but I can see looking at your code that there a lot of fancy extras that would take a large effort to write in pytensor. Still, if we want to go back to one of our efforts with a fixed point operator (pymc-devs/pytensor#978 and pymc-devs/pytensor#944), we could probably write find_MAP with that in some form, with fewer bells and whistles though.

Happy to look at your code and review properly later in the week if you'd like me to. Let me know. Otherwise, I'll leave to the core devs.

@ricardoV94
Copy link
Member

Happy to look at your code and review properly later in the week if you'd like me to. Let me know. Otherwise, I'll leave to the core devs.

That would be appreciated

@ricardoV94
Copy link
Member

Agree with what @theorashid said. This fit_laplace is going for the same goal as the previous one. Happy to replace it, if it's not married to JAX backend. Still fine to allow using JAX for the autodiff. What you're offering is very similar to nutpie gradient_backend kwarg, so we could use the same terminology

@ricardoV94
Copy link
Member

No objections about your custom library wrapper

@jessegrabowski
Copy link
Member Author

tagging @theorashid -- I couldn't pick you as a reviewer?

I did a major refactor of this. I broke the marriage to jax and generalized the find_MAP function. Files have been renamed to reflect this. I also merged the two laplace approaches. The biggest change is that I removed the ability to choose vars. I think the idea here was to be able to partially marginalize some variables in a model? But I think this would require a somewhat different approach.

@theorashid
Copy link
Contributor

yea sorry I'm just a normal, but I'll give it a review. Will do it at some point in the next 2 weeks.

Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Minor suggestions, PR looks amazing!

@ricardoV94
Copy link
Member

ricardoV94 commented Dec 4, 2024

@jessegrabowski can we close #376 with this PR?

Do you have a test that covers something like it?

@jessegrabowski
Copy link
Member Author

Yes, I think this test and this test should cover that issue

@ricardoV94 ricardoV94 added the enhancements New feature or request label Dec 4, 2024
@ricardoV94 ricardoV94 changed the title Add JAX-based find_MAP function Add find_MAP with close JAX integration and fix bug with Laplace fit Dec 4, 2024
@jessegrabowski jessegrabowski merged commit 5055262 into pymc-devs:main Dec 4, 2024
7 checks passed
@theorashid
Copy link
Contributor

sweet, all done?

@jessegrabowski
Copy link
Member Author

For now, though I'd still appreciate it if you could have a look and open issues on any bugs/shortcomings you find

Copy link
Contributor

@theorashid theorashid left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I managed to follow the code through and it looks good to me. Happy you got rid of the option to fit on a subset of variables, which didn't make sense to me anyway. If it passes the original test then it should be good. You can do something about the other comments if you want, but maybe not because we are e x p e r i m e n t a l

return f_loss_and_grad, f_hess, f_hessp


def _compile_functions(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

NIT: Maybe _compile_functions and _compile_jax_gradients are slightly too generic function names. I found it a little tricky to remember exactly what they were doing when reading through the code

use_hess = use_hess if use_hess is not None else method_info["uses_hess"]
use_hessp = use_hessp if use_hessp is not None else method_info["uses_hessp"]

if use_hess and use_hessp:
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I was going through all the methods thinking when you would need hess and hessp and then came back to this. I would probably warn the user / not let them pass both use_hess and use_hessp

return idata


def fit_mvn_to_MAP(
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

fit_mvn_at_MAP? I mean technically this function just fits a MVN at a point, the user doesn't necessarily have to pass the MAP

H_inv = get_nearest_psd(H_inv)
if on_bad_cov == "warn":
_log.warning(
"Inverse Hessian is not positive semi-definite at the provided point, using the closest PSD "
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

For my understanding, what sort of scenarios/models would get a not PSD hessian. And is using closest PSD a good ideas?


Parameters
----------
mu
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

add docs here

and 1).

.. warning::
This argumnet should be considered highly experimental. It has not been verified if this method produces
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

*argument

gradient_backend: str, default "pytensor"
The backend to use for gradient computations. Must be one of "pytensor" or "jax".
chains: int, default: 2
The number of sampling chains running in parallel.
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'd probably add something here reiterating that this isn't a sampling inference method. This is just sampling from the approximated posterior. There was already people in the forum asking about the differences in these methods



@pytest.mark.parametrize(
"method, use_grad, use_hess",
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

any use_hessp tests? or are we just testing if scipy works here

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
enhancements New feature or request
Projects
None yet
Development

Successfully merging this pull request may close these issues.

Laplace approximation not handling non-scalar parameters
3 participants