MCMC in parallel with different datasets #139
-
Hello! I'm wondering whether it's possible to run several MCMC chains in parallel, each targeting the posterior of the same model but conditioned on a different dataset. Is this possible with BlackJax, or does the code running on each device have to be the same? I'm asking because the kernel needs to be built before defining the sampler. So doing a vmap/pmap over different datasets might be possible, but the function being mapped would need to build the kernel. Thanks!
-
Yes, this should be possible by vmapping the inference/sampling routine, something like:

```python
import jax

# define a function that runs inference on 1 dataset:
@jax.vmap
def run_inference(dataset):
    # build the kernel for this dataset, then run the sampling loop
    ...
```

The kernel building will be encapsulated within the `run_inference` function.
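A minimal, self-contained sketch of that pattern follows, under some loudly stated assumptions. To avoid guessing at a specific BlackJax version's API, the kernel here is a hand-written random-walk Metropolis step in plain JAX standing in for a BlackJax kernel. The toy model (unknown Gaussian mean `mu` with a standard-normal prior), the helper names `make_logdensity` and `build_rmh_kernel`, the step size, and the synthetic data are all hypothetical. The point it illustrates is that each dataset is closed over by its own log-density, and the kernel is built inside the vmapped function, so every batch element samples its own posterior.

```python
import jax
import jax.numpy as jnp


def make_logdensity(dataset):
    # Hypothetical model: x_i ~ N(mu, 1), prior mu ~ N(0, 1).
    # Closing over `dataset` gives one log-density per dataset.
    def logdensity(mu):
        log_prior = -0.5 * mu**2
        log_lik = -0.5 * jnp.sum((dataset - mu) ** 2)
        return log_prior + log_lik
    return logdensity


def build_rmh_kernel(logdensity, step_size=0.5):
    # Stand-in for a BlackJax kernel (NOT the BlackJax API):
    # one random-walk Metropolis step on a scalar position.
    def step(rng_key, position):
        key_prop, key_accept = jax.random.split(rng_key)
        proposal = position + step_size * jax.random.normal(key_prop)
        log_ratio = logdensity(proposal) - logdensity(position)
        accept = jnp.log(jax.random.uniform(key_accept)) < log_ratio
        return jnp.where(accept, proposal, position)
    return step


@jax.vmap
def run_inference(rng_key, dataset):
    # The kernel is built inside the mapped function, so each
    # dataset in the batch gets a kernel for its own posterior.
    kernel = build_rmh_kernel(make_logdensity(dataset))

    def one_step(position, key):
        position = kernel(key, position)
        return position, position

    keys = jax.random.split(rng_key, 2000)
    _, samples = jax.lax.scan(one_step, jnp.zeros(()), keys)
    return samples


# Three hypothetical datasets stacked along axis 0 (vmap needs equal shapes).
data_key, run_key = jax.random.split(jax.random.PRNGKey(0))
true_means = jnp.array([-2.0, 0.0, 3.0])
datasets = true_means[:, None] + jax.random.normal(data_key, (3, 50))
rng_keys = jax.random.split(run_key, datasets.shape[0])

samples = run_inference(rng_keys, datasets)
print(samples.shape)  # (3, 2000)
```

Because `vmap` maps over a leading axis, the datasets have to be stacked into one array of a common shape (ragged datasets would need padding or masking). Replacing `jax.vmap` with `jax.pmap` maps the same function over devices instead, in which case the leading axis should be no larger than `jax.local_device_count()`.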