-
-
Notifications
You must be signed in to change notification settings - Fork 51
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Implement Laplace (quadratic) approximation #345
Conversation
I am looking for the best way to return not just a posterior sample distribution but also the mean vector and covariance matrix of the Gaussian distribution. Any suggestion for this. So far my only idea is to add another section to the inferenceData returned containing this information. Thoughts on this? |
I'm not sure the InferenceData is the best place to put it. We should copy
whatever we do with Variational Inference
…On Sun, 2 Jun 2024, 13:05 Carsten Jørgensen, ***@***.***> wrote:
I am looking for the best way to return not just a posterior sample
distribution but also the mean vector and covariance matrix of the Gaussian
distribution. Any suggestion for this. So far my only idea is to add
another section to the inferenceData returned containing this information.
Thoughts on this?
—
Reply to this email directly, view it on GitHub
<#345 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCUM45EIPXABROOJPVCTZFL353AVCNFSM6AAAAABIT3F2F2VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCNBTHAYDGOBQGA>
.
You are receiving this because you are subscribed to this thread.Message
ID: ***@***.***>
|
Historically, Inferencedata has been focused on mcmc. But we have discussed a few times extend it to better handle other inference methods, like SMC or variational methods. It just that there has not been enough momentum to agree and implement and schema that works for those methods. |
@zaxtax and @aloctavodia are you saying that I should not return inferencedata at all or just not return the gaussian mean and covariance in the inferencedata object? I am new to both PYMC and Bayesian statistics so I do not know the history of this package. |
Oh, it's more that we haven't decided how to handle this within the
library. Don't treat this as a blocker, though we should raise it for
discussion more broadly
…On Sun, 2 Jun 2024, 19:04 Carsten Jørgensen, ***@***.***> wrote:
@zaxtax <https://github.com/zaxtax> and @aloctavodia
<https://github.com/aloctavodia> are you saying that I should not return
inferencedata at all or just not return the gaussian mean and covariance in
the inferencedata object? I am new to both PYMC and Bayesian statistics so
I do not know the history of this package.
Best, Carsten
—
Reply to this email directly, view it on GitHub
<#345 (comment)>,
or unsubscribe
<https://github.com/notifications/unsubscribe-auth/AAACCUKGG7KIW4SVLMG6ZLLZFNGB7AVCNFSM6AAAAABIT3F2F2VHI2DSMVQWIX3LMV43OSLTON2WKQ3PNVWWK3TUHMZDCNBTHE2DINBZGU>
.
You are receiving this because you were mentioned.Message ID:
***@***.***>
|
CC @ferrine |
exactly, just saying that if necessary InferenceData can be extended. |
Suggestion, include two groups in the returned inferencedata:
We could even try different fits from distinct initialization points (optionally) and save those as distinct "chains" in the |
@carsten-j PR looks great! I left some comment above |
Thanks you @ricardoV94 and @twiecki for the review comments. I believe that all of them expect one has been fixed. I have not figured out how to use |
The docs contains code example: https://www.pymc.io/projects/docs/en/stable/api/model/generated/pymc.model.transform.conditioning.remove_value_transforms.html |
I should have mentioned that I did read the doc and looked at the example. But I have not been able to figure out how to apply it to my case. I will try again ... |
To be able to use it inside the model context, it will need this change to get merged first: pymc-devs/pymc#7352 But you should be able to already test by doing the object way with |
@ricardoV94 I figured out how to replace the for loop with remove_value_transforms. Is the PR ready for merge or are there additional review comments? |
logsigma = pm.Uniform("logsigma", 1, 100) | ||
mu = pm.Uniform("mu", -10000, 10000) | ||
yobs = pm.Normal("y", mu=mu, sigma=pm.math.exp(logsigma), observed=y) | ||
vars = [mu, logsigma] |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Question: say you only did vars=[mu]
, how would the variable logsigma
be estimated?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I think find_MAP in that case uses the initial_point for the excluded variable(s). I never found that behavior useful tbh
Edit: Maybe it's fine. Either way it's documented here: https://github.com/pymc-devs/pymc/blob/05b557f6460a10c29c3db33690ee535f5b1ecde0/pymc/tuning/starting.py#L73-L75
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Sounds like we may want to pass a similar start
kwarg to laplace to set the value of variables that are not being optimized?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
worth adding a test on this to confirm the behaviour
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I am not sure I fully understand this. I will give it a second go with the documentation for find_MAP.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi Carsten, is there anything we can do to help get this over the line?
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Hi @theorashid. I am not sure how to handle if only a subset of the variables are passed in, e.g. vars=[mu] and log_sigma is left out. If this should raise a warning I need some way of figuring out the number of model parameters and compare that with the number of parameters in vars. I am not sure how to determine the number of model parameters
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
model.free_RVs
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
@theorashid and @ricardoV94, I committed an update that will raise a warning in case number of variables in vars does not equal number of model variables.
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
LGTM
@carsten-j tests are no longer failing in main. You can rebase/merge into your branch |
* Allow forward sampling of statespace models in JAX mode Explicitly set data shape to avoid broadcasting error Better handling of measurement error dims in `SARIMAX` models Freeze auxiliary models before forward sampling Bugfixes for posterior predictive sampling helpers Allow specification of time dimension name when registering data Save info about exogenous data for post-estimation tasks Restore `_exog_data_info` member variable Be more consistent with the names of filter outputs * Adjust test suite to reflect API changes Modify structural tests to accommodate deterministic models Save kalman filter outputs to idata for statespace tests Remove test related to `add_exogenous` Adjust structural module tests * Add JAX test suite * Bug-fixes and changes to statespace distributions Remove tests related to the `add_exogenous` method Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"` Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs Add signature and simple test for `SequenceMvNormal` * Re-run example notebooks * Add helper function to sample prior/posterior statespace matrices * fix tests * Wrap jax MvNormal rewrite in try/except block * Don't use `action` keyword in `catch_warnings` * Skip JAX test if `numpyro` is not installed * Handle batch dims on `SequenceMvNormal` * Remove unused batch_dim logic in SequenceMvNormal * Restore `get_support_shape_1d` import
Check out this pull request on See visual diffs & provide feedback on Jupyter Notebooks. Powered by ReviewNB |
@ricardoV94 I have rebased the laplace branch but it looks like someone needs to approve Github worksflows. |
Looks like there are still a few failing tests, but once those pass this is probably good to merge |
@zaxtax failing test has been fixed. Can you approve the waiting workflow? |
@zaxtax, all tests passed. Are you also able to merge the PR? Thanks. |
Congrats @carsten-j, this is a big one! |
Thank you @twiecki. Really happy to contribute and thanks to all those that helped. After the summer I will try to work on documentation for building and running locally. I took me some time to figure out how this works! |
Congrats @carsten-j this is really neat! |
Brilliant work @carsten-j . Hope to see you contribute to PyMC again! |
This is an early version of q quadratic approximation implementation that I have developed while reading Statistical Rethinking by Richard McElreath.
There is a short discussion about this in the issue and maybe @theorashid can help with feedback of this draft PR.
This work is partly based on the Python package pymc3-quap but pymc3-quap is based on PYMC3 and a lot happend bewteen version 3 and 5 of PYMC. Optimizers works better when provided with a good initial guess and hence a (optional) starting point has been added to function arguments. Please see Github for a discussion about the differences between PYMC version 3 and 5 for computing the Hessian.