Use jaxified logp for initial point evaluation when sampling via Jax #7610

Open
wants to merge 11 commits into main

Conversation


@nataziel commented Dec 11, 2024

Use jaxified logp for initial point evaluation when sampling via Jax

Description

  • get jaxified logp function in sample_jax_nuts
    • uses different parameters to get jaxified function depending on which nuts_sampler is specified
  • pass jaxified logp function to _get_batched_jittered_initial_points
    • added logp_fn parameter to function signature
    • wrap passed function to conform to how _init_jitter will call it (see the sketch after this list)
  • pass wrapped function to _init_jitter
    • added logp_fn parameter to function signature
  • added logic in _init_jitter to decide which function to use to evaluate the generated points
  • added a bunch of type annotations
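
Roughly, the wrapping described above looks like the sketch below. The helper name make_jitter_logp_fn is illustrative only (not the code in this PR); get_jaxified_logp is the existing helper in pymc.sampling.jax, and the exact negative_logp argument passed depends on which nuts_sampler is selected.

    import numpy as np

    from pymc.sampling.jax import get_jaxified_logp


    def make_jitter_logp_fn(model):
        # The jaxified logp takes the model's value variables as a flat
        # sequence of arrays and returns a scalar (a jax.Array at runtime).
        logp_fn = get_jaxified_logp(model)

        # _init_jitter evaluates candidate points supplied as
        # {value_var_name: np.ndarray} dicts, so wrap the jaxified function
        # to accept that shape of input. This assumes the dict preserves the
        # ordering of model.value_vars, as the initial point machinery does.
        def eval_point(point: dict[str, np.ndarray]):
            return logp_fn(list(point.values()))

        return eval_point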

Related Issue

Checklist

  • Checked that the pre-commit linting/style checks pass
  • Included tests that prove the fix is effective or that the new feature works
  • Added necessary documentation (docstrings and/or example notebooks) - I have a question about the sample_blackjax_nuts function docstring; I'll put it in a comment below
  • If you are a pro: each commit corresponds to a relevant logical change

Type of change

  • New feature / enhancement
  • Bug fix
  • Documentation
  • Maintenance
  • Other (please specify):

📚 Documentation preview 📚: https://pymc--7610.org.readthedocs.build/en/7610/


welcome bot commented Dec 11, 2024

💖 Thanks for opening this pull request! 💖 The PyMC community really appreciates your time and effort to contribute to the project. Please make sure you have read our Contributing Guidelines and filled in our pull request template to the best of your ability.

@github-actions bot added the bug label Dec 11, 2024
@nataziel
Author

_sample_blackjax_nuts mentions initvals in the parameters section of its docstring. Is that a kwarg, or should it be changed to initial_points? I think it's the former, in which case I could add an initial_points parameter to the docstring. Happy to do that if it's correct.

pymc/initial_point.py: outdated review thread (resolved)
pymc/sampling/mcmc.py: outdated review thread (resolved)
@nataziel
Author

_sample_blackjax_nuts mentions initvals in the parameters section of its docstring. Is that a kwarg, or should it be changed to initial_points? I think it's the former, in which case I could add an initial_points parameter to the docstring. Happy to do that if it's correct.

@ricardoV94 any thoughts on this one?

    model_logp = model.logp()
    if not negative_logp:
        model_logp = -model_logp
    logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp])

-   def logp_fn_wrap(x):
+   def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray:
Member

This is not correct, it takes jax arrays and outputs jax arrays

Author
@nataziel commented Dec 12, 2024

I don't think that's 100% true. Checking with the interactive debugger confirms that the return type is jax.Array, but the initial point functions return a dict[str, np.ndarray], and we can successfully pass the .values() of that dict into the jaxified function. So it can seemingly accept anything that's coercible to an array. Maybe it's more correct to annotate it like this:

def logp_fn_wrap(x: ArrayLike) -> jax.Array:

ArrayLike is from numpy.typing: https://numpy.org/devdocs/reference/typing.html#numpy.typing.ArrayLike
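
A minimal standalone illustration of that point (plain JAX, not PyMC code): a jitted function accepts NumPy input but always returns a jax.Array, which is why ArrayLike -> jax.Array is arguably the tighter annotation.

    import jax
    import numpy as np


    @jax.jit
    def squared_sum(x):
        # Traced by JAX; works on anything coercible to an array.
        return (x ** 2).sum()


    out = squared_sum(np.arange(3.0))  # NumPy array in
    print(type(out))                   # jax.Array out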

Author

I've just pushed a commit to improve this. It's a bit tricky to annotate at the interface with _init_jitter given that jax is an optional dependency, so I've left the type annotation as returning np.ndarray but noted in the docstring that it may return a jax.Array.
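
A rough sketch of that compromise (the names PointLogpFn and evaluate_point are hypothetical, not the PR's actual code): keep the annotations NumPy-only so nothing needs to import jax, and note the jax.Array possibility in the docstring.

    from collections.abc import Callable, Sequence

    import numpy as np

    # jax is an optional dependency, so the public annotation sticks to NumPy
    # types even though a jaxified logp_fn returns a jax.Array at runtime.
    PointLogpFn = Callable[[Sequence[np.ndarray]], np.ndarray]


    def evaluate_point(point: dict[str, np.ndarray], logp_fn: PointLogpFn) -> np.ndarray:
        """Evaluate the model logp at a candidate initial point.

        When sampling via JAX, ``logp_fn`` is a jaxified graph, so the return
        value may actually be a ``jax.Array`` rather than a ``np.ndarray``.
        """
        return logp_fn(list(point.values()))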

@ricardoV94
Member

@nataziel good catch. The outer/user-facing functions take initvals which later get converted into initial_points in the inner functions. Feel free to update the docstrings of the inner functions if they still refer to initvals

@nataziel
Author

@nataziel good catch. The outer/user-facing functions take initvals which later get converted into initial_points in the inner functions. Feel free to update the docstrings of the inner functions if they still refer to initvals

Cleaned that up and added docstrings for the numpyro equivalent :)

@nataziel
Author

nataziel commented Dec 12, 2024

Not sure why the most recent commit didn't trigger the documentation check. I was able to run make rtd locally and the build succeeded.

edit: it wasn't using the locally installed version. With the local version properly installed, the build failed. Will try to debug.

@nataziel
Author

It seems sphinx autodoc didn't like if TYPE_CHECKING. The docs now build successfully with make rtd on my machine. Not sure how to manually trigger the check.

@nataziel requested a review from ricardoV94 on December 12, 2024 at 13:22
@nataziel
Author

nataziel commented Jan 6, 2025

@ricardoV94, anything you need from my end to get this one moving?

@nataziel
Author

@ricardoV94 sorry to be a pain; I'm really keen to get this implemented, as 5.19 is a huge speedup for my model, but I can't use it with the current initial point implementation. Anything I can do or change to get this through?

@ricardoV94
Member

@ricardoV94 sorry to be a pain; I'm really keen to get this implemented, as 5.19 is a huge speedup for my model, but I can't use it with the current initial point implementation. Anything I can do or change to get this through?

I'll review it today, apologies for the delay

pymc/model/core.py: outdated review thread (resolved)
pymc/sampling/jax.py: 4 outdated review threads (resolved)
Member
@ricardoV94 left a comment

Left some small comments regarding type-hints and docstrings. Otherwise the main functionality looks perfect and ready to merge

@nataziel
Author

No need to apologise for the delay; the holidays are always a busy time! I've implemented most of that feedback, but let me know what you think re: default values in the jax sampler docstrings.


codecov bot commented Jan 17, 2025

Codecov Report

Attention: Patch coverage is 95.65217% with 1 line in your changes missing coverage. Please review.

Project coverage is 92.78%. Comparing base (6cdfc30) to head (deea64c).
Report is 25 commits behind head on main.

Files with missing lines Patch % Lines
pymc/sampling/jax.py 95.23% 1 Missing ⚠️
Additional details and impacted files


@@            Coverage Diff             @@
##             main    #7610      +/-   ##
==========================================
- Coverage   92.83%   92.78%   -0.06%     
==========================================
  Files         106      107       +1     
  Lines       17748    18189     +441     
==========================================
+ Hits        16477    16876     +399     
- Misses       1271     1313      +42     
Files with missing lines Coverage Δ
pymc/initial_point.py 99.02% <ø> (ø)
pymc/sampling/mcmc.py 87.16% <100.00%> (+1.04%) ⬆️
pymc/sampling/jax.py 95.02% <95.23%> (+0.23%) ⬆️

... and 13 files with indirect coverage changes

@nataziel
Author

nataziel commented Jan 17, 2025

The missing line of coverage is for a defensive error that should never occur. It's a line I moved, and it was not covered before this change either.

@nataziel
Author

I can't see how anything I've touched could cause the test failure; there is code later in the test that references initial points, but it fails at L806 before any of that runs.

I can't replicate the failing test on my (Windows) machine; that entire TESTGARCH11 test class runs fine locally for me. I can see it has been flaky in the past, but it's (partially) seeded now, so I'm not sure why it failed here. Maybe the generation of param_val on L788 should be seeded as well to make it less flaky, but I think that's a separate issue.

Development

Successfully merging this pull request may close these issues.

BUG: model initial_point fails when pt.config.floatX = "float32"