Adaptive SMC as in PyMC? #735
Replies: 2 comments
-
Hi @jucor, nice to meet you, and thanks for your interest in all of this!
Apologies on this one: I don't fully understand the question. The notebook does make use of inner_kernel_tuning and of some primitives found in https://github.com/blackjax-devs/blackjax/tree/main/blackjax/smc/tuning, which compute means, standard deviations, covariance matrices, etc. from the particles. Regarding the inference loop, I think the core contributors' decision has always been not to include inference loops in the library. LMK if this is what you meant or if I am missing some part of the question. See the note below regarding pre-tuning. Also LMK if you had a different parameter tuning procedure in mind.
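As an illustration of what "tuning from the particles" can mean, here is a minimal plain-JAX sketch. The function names are hypothetical and this is not the code in blackjax/smc/tuning: it fits a Gaussian to the weighted particle cloud and uses it as the proposal of an independent Metropolis-Hastings (IRMH) kernel, i.e. the IRMH case of Design Choice (c) discussed in the question. The small diagonal jitter on the covariance is an assumption added to keep the sampling numerically stable.

```python
import jax
import jax.numpy as jnp
from jax.scipy.stats import multivariate_normal as mvn


def particle_moments(particles, log_weights, jitter=1e-6):
    """Weighted mean and covariance of a particle cloud of shape (N, d).

    Hypothetical helper; the small diagonal jitter keeps the covariance positive definite.
    """
    w = jax.nn.softmax(log_weights)  # normalised weights, summing to 1
    mean = jnp.sum(w[:, None] * particles, axis=0)
    centered = particles - mean
    cov = (w[:, None] * centered).T @ centered
    return mean, cov + jitter * jnp.eye(particles.shape[1])


def irmh_step(key, x, logdensity_fn, mean, cov):
    """One independent Metropolis-Hastings step with a Gaussian proposal N(mean, cov)."""
    key_prop, key_acc = jax.random.split(key)
    proposal = jax.random.multivariate_normal(key_prop, mean, cov)
    # The proposal does not depend on x, so the acceptance ratio needs q(x) / q(proposal).
    log_ratio = (logdensity_fn(proposal) - logdensity_fn(x)) + (
        mvn.logpdf(x, mean, cov) - mvn.logpdf(proposal, mean, cov)
    )
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, x)
```

Inside an SMC loop, particle_moments would be called once per tempering step, and irmh_step (wrapped in a closure and vmapped over keys and particles) then moves every particle with the same tuned proposal.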
Regarding the number of MCMC steps: based on conversations with @nchopin, Waste-Free SMC seemed like a better option (#721). Although it is not the only reason, WF is better suited to JAX because the number of steps is the same on all chains. And yes, blackjax.smc.adaptive_tempered implements the temperature schedule.
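To make the JAX argument concrete, here is a minimal sketch of a waste-free move in plain JAX (hypothetical function names, not the blackjax API): resample M = N / P starting points, run each for P - 1 Metropolis steps with lax.scan, and keep every visited state, so N particles come back. Because every chain runs exactly the same number of steps, the whole move vectorizes with vmap without any per-chain stopping logic. The random-walk Metropolis kernel is only a stand-in for whatever MCMC kernel is actually used.

```python
import jax
import jax.numpy as jnp


def rw_mh_step(key, x, logdensity_fn, step_size):
    """One Gaussian random-walk Metropolis step (stand-in for any MCMC kernel)."""
    key_prop, key_acc = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_prop, x.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, x)


def waste_free_move(key, particles, log_weights, logdensity_fn, chain_length, step_size):
    """Resample M = N // P seeds, run each for P - 1 steps, and keep every state."""
    n = particles.shape[0]
    num_chains = n // chain_length  # M; assumes N is divisible by P
    key_res, key_mcmc = jax.random.split(key)
    # Multinomial resampling of the M starting points.
    idx = jax.random.choice(key_res, n, shape=(num_chains,), p=jax.nn.softmax(log_weights))
    seeds = particles[idx]

    def run_chain(chain_key, x0):
        def body(x, k):
            x_new = rw_mh_step(k, x, logdensity_fn, step_size)
            return x_new, x_new

        keys = jax.random.split(chain_key, chain_length - 1)
        _, states = jax.lax.scan(body, x0, keys)
        return jnp.concatenate([x0[None], states], axis=0)  # (P, d), seed included

    chain_keys = jax.random.split(key_mcmc, num_chains)
    chains = jax.vmap(run_chain)(chain_keys, seeds)  # (M, P, d)
    # All M * P states form the new, equally weighted particle set.
    return chains.reshape(num_chains * chain_length, -1)
```

The design choice is visible in the shapes: the work per move is a fixed (M, P) grid, which is exactly what vmap and scan want, whereas adapting the number of steps per chain would require dynamic loop lengths.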
Right, PyMC already has these things ready out of the box; with Blackjax you need to assemble the pieces yourself. You might be interested in https://github.com/pymc-devs/pymc-experimental/blob/main/pymc_experimental/inference/smc/sampling.py, which I implemented some time ago. The first caveat is that a current pytensor issue is blocking the merge of this fix PR: pymc-devs/pymc-extras#374. The second caveat is that the parameter tuning needs improvement (see below).

Note on pre-tuning: it is on my near-term roadmap to implement the pre-tuning approach considered in the paper you cited. Right now you can only tweak parameters based on the particles or the sampler state. With the pre-tuning algorithm, an MCMC step would be executed, the performance of the parameters would be measured, and a population of parameters suited to the distribution at step T would be selected probabilistically. With pre-tuning, my understanding is that all the pieces would be in place for the Blackjax implementation to be competitive with the state of the art when the goal is sampling from a Bayesian posterior distribution.
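The pre-tuning idea described above could look roughly like the following sketch: give each particle its own step size, run one random-walk Metropolis move per particle, score each step size by the squared jump distance it produced, then resample the step sizes in proportion to their scores and perturb them. This is plain JAX, not a blackjax API; the function names, the random-walk kernel, the squared-jump-distance score and the log-normal jitter are all assumptions for illustration, not necessarily the exact procedure of the paper.

```python
import jax
import jax.numpy as jnp


def rw_mh_step(key, x, logdensity_fn, step_size):
    """One Gaussian random-walk Metropolis step with a per-particle step size."""
    key_prop, key_acc = jax.random.split(key)
    proposal = x + step_size * jax.random.normal(key_prop, x.shape)
    log_ratio = logdensity_fn(proposal) - logdensity_fn(x)
    accept = jnp.log(jax.random.uniform(key_acc)) < log_ratio
    return jnp.where(accept, proposal, x)


def pretune_step_sizes(key, particles, step_sizes, logdensity_fn, jitter=0.1):
    """Score each particle's step size with one MCMC move, then resample and perturb them.

    Hypothetical helper: the moved particles are discarded, only the step sizes are returned.
    """
    n = particles.shape[0]
    key_mcmc, key_res, key_noise = jax.random.split(key, 3)
    keys = jax.random.split(key_mcmc, n)
    step = lambda k, x, h: rw_mh_step(k, x, logdensity_fn, h)
    moved = jax.vmap(step)(keys, particles, step_sizes)
    # Performance measure (assumption): squared jump distance of the trial move.
    scores = jnp.sum((moved - particles) ** 2, axis=-1) + 1e-12
    idx = jax.random.choice(key_res, n, shape=(n,), p=scores / jnp.sum(scores))
    # Multiplicative log-normal jitter keeps the resampled step sizes positive.
    return step_sizes[idx] * jnp.exp(jitter * jax.random.normal(key_noise, (n,)))
```

Step sizes that are too small produce tiny jumps and step sizes that are too large are mostly rejected (zero jump), so resampling on this score tends to concentrate the population on intermediate values.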
Hope it helps!
-
Thanks a lot @ciguaran for the detailed answer, super helpful!
Thanks a lot for your super helpful answer! I'll now get to assembling the main inference loops. (If you know where I could read about the rationale for not including them in the library, I'd love to know!) Thanks!
-
Hi folks!
This question is probably best suited for @ciguaran, who worked on the SMC implementation and on the Tuning SMC notebook, which tunes the MCMC kernel inside SMC and thereby handles Design Choice (c) of Section 2.1.3 of Buchholz, Chopin, Jacob (2020) for the case of an IRMH kernel.
Do you have any plan to merge this implementation of the MCMC kernel tuning into the blackjax library itself, please? That would avoid copying, pasting and modifying the specific case in the notebook, and be easier to use ;-) I realise this might be `blackjax.smc.inner_kernel_tuning`, so would you have an example of usage, please?

Have you implemented Design Choice (b) anywhere, i.e. adapting the number of MCMC steps based on Algorithm 3, by any chance? I believe `blackjax.smc.adaptive_tempered` implements Algorithm 2 for Design Choice (a): the temperature schedule, based on ESS.

Reason I'm asking: I'm moving a codebase from PyMC to Blackjax to benefit from JAX's `vmap` (each likelihood evaluation requires solving an ODE, which is costly, but `vmap` with diffrax makes it embarrassingly parallel on GPU). PyMC's SMC sampler, while regrettably not GPU-parallelizable (at least as far as I know), has all those adaptations already baked in, which makes it quite easy to use for applied statisticians who are not SMC specialists.

It would be terrific to have those in blackjax, please :)
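For reference, the ESS-based temperature choice of Design Choice (a) can be sketched as a bisection: with equally weighted particles at temperature λ, find the largest increment δ such that the ESS of the incremental weights exp(δ · loglik) stays at or above a target fraction of N. This is a plain-JAX sketch with hypothetical names; it does not reproduce the internals of `blackjax.smc.adaptive_tempered`. Bisection is valid here because the ESS of these weights is non-increasing in δ.

```python
import jax
import jax.numpy as jnp


def ess_fraction(log_weights):
    """Effective sample size divided by N, from unnormalised log-weights."""
    w = jax.nn.softmax(log_weights)
    return 1.0 / (jnp.sum(w ** 2) * log_weights.shape[0])


def next_temperature(loglik, current_lambda, target_ess=0.5, num_iters=50):
    """Bisection for the largest increment delta with ESS(delta * loglik) >= target_ess * N.

    Assumes equally weighted particles at current_lambda (e.g. right after resampling);
    loglik holds the log-likelihood of each particle.
    """
    upper = 1.0 - current_lambda

    def body(_, bounds):
        lo, hi = bounds
        mid = 0.5 * (lo + hi)
        ok = ess_fraction(mid * loglik) >= target_ess
        # Keep the invariant: ESS at lo is above target, ESS at hi is below it.
        return jnp.where(ok, mid, lo), jnp.where(ok, hi, mid)

    lo, _ = jax.lax.fori_loop(0, num_iters, body, (jnp.asarray(0.0), jnp.asarray(upper)))
    # If even the full step to lambda = 1 keeps the ESS above target, jump straight to 1.
    delta = jnp.where(ess_fraction(upper * loglik) >= target_ess, upper, lo)
    return current_lambda + delta
```

In a full tempered SMC loop, this would be called once per iteration with the vmapped log-likelihood values of the current particles, followed by reweighting, resampling and an MCMC move.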
Thanks!