-
-
Notifications
You must be signed in to change notification settings - Fork 2k
New issue
Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.
By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.
Already on GitHub? Sign in to your account
Use jaxified logp for initial point evaluation when sampling via Jax #7610
base: main
Are you sure you want to change the base?
Conversation
] |
|
@ricardoV94 any thoughts on this one? |
pymc/sampling/jax.py
Outdated
model_logp = model.logp() | ||
if not negative_logp: | ||
model_logp = -model_logp | ||
logp_fn = get_jaxified_graph(inputs=model.value_vars, outputs=[model_logp]) | ||
|
||
def logp_fn_wrap(x): | ||
def logp_fn_wrap(x: Sequence[np.ndarray]) -> np.ndarray: |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
This is not correct, it takes jax arrays and outputs jax arrays
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I don't think that's 100% true. Checking with the interactive debugger confirms that the return type is jax.Array
, but the initial point functions return a dict[str, np.ndarray]
, and we can successfully pass the .values()
of that dict into the jaxified function. So it can seemingly accept anything that's coercible to an array. Maybe it's more correct to annotate it like this:
def logp_fn_wrap(x: ArrayLike) -> jax.Array:
ArrayLike is from numpy.typing
: https://numpy.org/devdocs/reference/typing.html#numpy.typing.ArrayLike
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
I've just pushed a commit to improve this, it's a bit tricky to annotate at the interface with _init_jitter
given that jax
is an optional dependency. I've left the type annotation as returning a np.ndarray
but included that it may return a jax.Array
in the docstring.
@nataziel good catch. The outer/user-facing functions take initvals which later get converted into |
Cleaned that up and added docstrings for the numpyro equivalent :) |
Not sure why the most recent commit didn't trigger the documentation check. edit: it wasn't using the locally installed version. With the local version properly installed it failed. Will try debug |
Seems sphinx autodoc didn't like |
@ricardoV94, anything you need from my end to get this one moving? |
@ricardoV94 sorry to be a pain, I'm really keen to get this implemented as 5.19 is a huge speedup to my model but I can't use it with the current initial point implementation. Anything I can do or change to get this through? |
I'll review it today, apologies for the delay |
There was a problem hiding this comment.
Choose a reason for hiding this comment
The reason will be displayed to describe this comment to others. Learn more.
Left some small comments regarding type-hints and docstrings. Otherwise the main functionality looks perfect and ready to merge
No need to apologise on the delay, the holidays are always a busy time! Have implemented most of that feedback but let me know what you think re: default values in jax sampler docstrings |
Codecov ReportAttention: Patch coverage is
Additional details and impacted files@@ Coverage Diff @@
## main #7610 +/- ##
==========================================
- Coverage 92.83% 92.78% -0.06%
==========================================
Files 106 107 +1
Lines 17748 18189 +441
==========================================
+ Hits 16477 16876 +399
- Misses 1271 1313 +42
|
The missing line of coverage is for a defensive error that should never occur. It's a line I moved, and was not covered before this change. |
I can't see where anything I've touched could cause the test failure, there is code later on in the test that references initial points but it fails at L806 before any of that occurs. I can't replicate the failing test on my (windows) machine, that entire TESTGARCH11 test class runs fine locally for me. I can see it has been flaky in the past but it's (partially) seeded now so I'm not sure why it's failed here. Maybe the generation of |
Use jaxified logp for initial point evaluation when sampling via Jax
Description
sample_jax_nuts
nuts_sampler
is specified_get_batched_jittered_initial_points
logp_fn
parameter to function signature_init_jitter
will call it_init_jitter
logp_fn
parameter to function signature_init_jitter
to decide which function to use to evaluate the generated pointsRelated Issue
pt.config.floatX = "float32"
#7608Checklist
sample_blackjax_nuts
function docstring, will put in comment belowType of change
📚 Documentation preview 📚: https://pymc--7610.org.readthedocs.build/en/7610/