Using non-JAX-traceable logpdfs with BlackJAX #136
-
Thanks for using BlackJAX! I'll let @mattjj answer the more general JAX question (you're actually not the first to ask). On the BlackJAX side, there is currently no way for users to provide their own gradient because of the way the symplectic integrators are implemented. However, we are currently integrating the stochastic gradient methods from SGMCMCJAX; those require the user to pass a gradient estimator, and we'll harmonize the APIs so users can also provide gradients for HMC/NUTS.
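To make the point about gradients concrete: a symplectic integrator for HMC only needs the gradient of the log density at the current position, so in principle that gradient could come from `jax.grad` or from a user-supplied function. The toy leapfrog step below takes the gradient function as an argument. It is a sketch assuming an identity mass matrix and is not BlackJAX's integrator API; `leapfrog` and `user_grad` are hypothetical names.

```python
import jax
import jax.numpy as jnp


def leapfrog(position, momentum, step_size, logdensity_grad_fn):
    """One leapfrog step with an identity mass matrix (hypothetical helper).

    `logdensity_grad_fn` can be `jax.grad(logdensity_fn)` or any function
    that returns the gradient of the log density at `position`.
    """
    momentum = momentum + 0.5 * step_size * logdensity_grad_fn(position)
    position = position + step_size * momentum
    momentum = momentum + 0.5 * step_size * logdensity_grad_fn(position)
    return position, momentum


def logdensity_fn(x):
    # Standard normal log density, up to an additive constant.
    return -0.5 * jnp.sum(x**2)


def user_grad(x):
    # Hand-written gradient of logdensity_fn, standing in for a user-supplied one.
    return -x


x0 = jnp.array([1.0, -2.0])
p0 = jnp.array([0.5, 0.5])

# Default: let JAX differentiate the log density.
x_auto, p_auto = leapfrog(x0, p0, 0.1, jax.grad(logdensity_fn))
# Alternative: pass the user-provided gradient instead.
x_user, p_user = leapfrog(x0, p0, 0.1, user_grad)
```

Both calls produce the same step here because `user_grad` is exactly the gradient of `logdensity_fn`; for a model JAX cannot trace, a supplied gradient would be the only option.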
-
Yes, it seems like it's not difficult now to add support for arbitrary Python functions to work with …
-
@sethaxen did you find a way in the end?
-
I'm positive this is possible by following the links @IvanYashchuk provided, and I am thinking about writing a "How to" tutorial, maybe using PyTorch?
-
It is possible; we just pushed an example to the documentation: https://blackjax-devs.github.io/blackjax/examples/howto_other_frameworks.html
-
Is it possible to use functions that are not traceable by JAX with BlackJAX? Some use cases:
In the first case, the object itself might live outside the logpdf function but be updated internally by the logpdf function using the sampled parameters in order to compute the logpdf and its gradient. At the very least, this puts constraints on parallelism.
Is this currently possible using BlackJAX? This might be more of a JAX question than a BlackJAX one.
cc @mattjj
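On the JAX side of the question, one possible approach (a sketch, not necessarily what the BlackJAX how-to does) is to call the external code through `jax.pure_callback` and attach its gradient with `jax.custom_vjp`, so that `jax.grad`, and therefore an HMC integrator, can differentiate it. `ExternalModel` is a hypothetical stand-in for the stateful object in the first use case: its parameters are overwritten on every call, so the result still depends only on the input, which is what `pure_callback` assumes (JAX treats the callback as pure and may skip or repeat calls). Because every evaluation mutates one shared host object, running several chains in parallel would need extra care, which matches the parallelism concern.

```python
import jax
import jax.numpy as jnp
import numpy as np


class ExternalModel:
    """Hypothetical stateful object that JAX cannot trace (e.g. a simulator)."""

    def __init__(self):
        self.params = None

    def set_params(self, x):
        # Overwrite the internal state with the sampled parameters.
        self.params = np.array(x, dtype=np.float64)

    def logdensity(self):
        # Standard normal log density, up to an additive constant.
        return -0.5 * np.sum(self.params**2)

    def grad_logdensity(self):
        return -self.params


model = ExternalModel()


def _host_value_and_grad(x):
    # Runs on the host with NumPy arrays; free to call arbitrary Python code.
    model.set_params(x)
    value = np.asarray(model.logdensity(), dtype=x.dtype)
    grad = np.asarray(model.grad_logdensity(), dtype=x.dtype)
    return value, grad


def _value_and_grad(x):
    # Declare output shapes/dtypes so JAX can trace around the callback.
    result_shapes = (
        jax.ShapeDtypeStruct((), x.dtype),
        jax.ShapeDtypeStruct(x.shape, x.dtype),
    )
    return jax.pure_callback(_host_value_and_grad, result_shapes, x)


@jax.custom_vjp
def logdensity_fn(x):
    value, _ = _value_and_grad(x)
    return value


def _logdensity_fwd(x):
    value, grad = _value_and_grad(x)
    # Keep the externally computed gradient as the residual.
    return value, grad


def _logdensity_bwd(grad, cotangent):
    # Chain rule for a scalar output: scale the stored gradient.
    return (cotangent * grad,)


logdensity_fn.defvjp(_logdensity_fwd, _logdensity_bwd)


x = jnp.array([1.0, -2.0])
value = jax.jit(logdensity_fn)(x)
grad = jax.jit(jax.grad(logdensity_fn))(x)
```

Here `jax.grad` never traces into `ExternalModel`; it only sees the gradient returned by the callback, so the same wrapped `logdensity_fn` could be handed to a sampler that calls `jax.grad` internally.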