Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

More bugfixes for statespace #346

Merged
merged 13 commits into from
Jun 29, 2024

Conversation

jessegrabowski
Copy link
Member

@jessegrabowski jessegrabowski commented Jun 3, 2024

The statespace model is still broken, this PR is another round of bugfixes.

  • Update structural notebook to conform to API changes after Bug fixes for statespace #326
  • Remove the numpy helper functions from the structural notebook, use pytensor only
  • Fix JAX-based forward sampling
  • Re-run all notebooks without errors
  • Get all current tests to pass
  • Add JAX tests

I need some help with fixing the JAX forward sampling. I'm doing something wrong, because even after freezing I have dynamic shape errors. This is the major blocker to considering the module "working" again.

Copy link

Check out this pull request on  ReviewNB

See visual diffs & provide feedback on Jupyter Notebooks.


Powered by ReviewNB

@jessegrabowski
Copy link
Member Author

This is ready for review. I will need to make changes again after the next pymc/pytensor releases, but for now this all works -- tests pass, and all notebooks run.

For review I tried to organize the commits into chunks. It would be good to have eyeballs on the changes related to the distributions, since those were creating trouble for me in the first place.

@jessegrabowski
Copy link
Member Author

No idea what's going on with the CI, there seems to be a lot of broken stuff unrelated to this.

Explicitly set data shape to avoid broadcasting error

Better handling of measurement error dims in `SARIMAX` models

Freeze auxiliary models before forward sampling

Bugfixes for posterior predictive sampling helpers

Allow specification of time dimension name when registering data

Save info about exogenous data for post-estimation tasks

Restore `_exog_data_info` member variable

Be more consistent with the names of filter outputs
Modify structural tests to accommodate deterministic models

Save kalman filter outputs to idata for statespace tests

Remove test related to `add_exogenous`

Adjust structural module tests
Remove tests related to the `add_exogenous` method

Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"`

Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs

Add signature and simple test for `SequenceMvNormal`
Copy link
Member

@ricardoV94 ricardoV94 left a comment

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Approving so you can merge, let me know if you want me to look at some specific changes carefully

@jessegrabowski jessegrabowski merged commit e85677b into pymc-devs:main Jun 29, 2024
8 checks passed
carsten-j pushed a commit to carsten-j/pymc-experimental that referenced this pull request Jun 30, 2024
* Allow forward sampling of statespace models in JAX mode

Explicitly set data shape to avoid broadcasting error

Better handling of measurement error dims in `SARIMAX` models

Freeze auxiliary models before forward sampling

Bugfixes for posterior predictive sampling helpers

Allow specification of time dimension name when registering data

Save info about exogenous data for post-estimation tasks

Restore `_exog_data_info` member variable

Be more consistent with the names of filter outputs

* Adjust test suite to reflect API changes

Modify structural tests to accommodate deterministic models

Save kalman filter outputs to idata for statespace tests

Remove test related to `add_exogenous`

Adjust structural module tests

* Add JAX test suite

* Bug-fixes and changes to statespace distributions

Remove tests related to the `add_exogenous` method

Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"`

Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs

Add signature and simple test for `SequenceMvNormal`

* Re-run example notebooks

* Add helper function to sample prior/posterior statespace matrices

* fix tests

* Wrap jax MvNormal rewrite in try/except block

* Don't use `action` keyword in `catch_warnings`

* Skip JAX test if `numpyro` is not installed

* Handle batch dims on `SequenceMvNormal`

* Remove unused batch_dim logic in SequenceMvNormal

* Restore `get_support_shape_1d` import
twiecki pushed a commit that referenced this pull request Jul 1, 2024
* First draft of quadratic approximation

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Review comments incorporated

* License and copyright information added

* Only add additional data to inferencedata when chains!=0

* Raise error if Hessian is singular

* Replace for loop with call to remove_value_transforms

* Pass model directly when finding MAP and the Hessian

* Update pymc_experimental/inference/laplace.py

Co-authored-by: Ricardo Vieira <[email protected]>

* Remove chains from public parameters for Laplace approx method

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Parameter draws is not optional with default value 1000

* [pre-commit.ci] auto fixes from pre-commit.com hooks

for more information, see https://pre-commit.ci

* Add warning if numbers of variables in vars does not equal number of model variables

* Update version.txt

* `shock_size` should never be scalar

* Blackjax API change

* Handle latest PyMC/PyTensor breaking changes

* Temporarily mark two tests as xfail

* More bugfixes for statespace (#346)

* Allow forward sampling of statespace models in JAX mode

Explicitly set data shape to avoid broadcasting error

Better handling of measurement error dims in `SARIMAX` models

Freeze auxiliary models before forward sampling

Bugfixes for posterior predictive sampling helpers

Allow specification of time dimension name when registering data

Save info about exogenous data for post-estimation tasks

Restore `_exog_data_info` member variable

Be more consistent with the names of filter outputs

* Adjust test suite to reflect API changes

Modify structural tests to accommodate deterministic models

Save kalman filter outputs to idata for statespace tests

Remove test related to `add_exogenous`

Adjust structural module tests

* Add JAX test suite

* Bug-fixes and changes to statespace distributions

Remove tests related to the `add_exogenous` method

Add dummy `MvNormalSVDRV` for forward jax sampling with `method="SVD"`

Dynamically generate `LinearGaussianStateSpaceRV` signature from inputs

Add signature and simple test for `SequenceMvNormal`

* Re-run example notebooks

* Add helper function to sample prior/posterior statespace matrices

* fix tests

* Wrap jax MvNormal rewrite in try/except block

* Don't use `action` keyword in `catch_warnings`

* Skip JAX test if `numpyro` is not installed

* Handle batch dims on `SequenceMvNormal`

* Remove unused batch_dim logic in SequenceMvNormal

* Restore `get_support_shape_1d` import

* Fix failing test case for laplace

---------

Co-authored-by: pre-commit-ci[bot] <66853113+pre-commit-ci[bot]@users.noreply.github.com>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
Co-authored-by: Ricardo Vieira <[email protected]>
Co-authored-by: Jesse Grabowski <[email protected]>
Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Projects
None yet
Development

Successfully merging this pull request may close these issues.

2 participants