Customize MCMC algorithm #1632
I found that this code may work, but I don't know how to use my translation_fn with it:

from collections import namedtuple
import os, time
import jax
import copy
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, BarkerMH, NUTS
MHState = namedtuple("MHState", ["u", "score", "rng_key"])
class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
    sample_field = "u"

    def __init__(self, potential_fn, translation_fn, beta=100, step_size=0.1):
        self.potential_fn = potential_fn
        self.translation_fn = translation_fn
        self.step_size = step_size
        self.beta = beta

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        u = init_params
        return MHState(u, self.potential_fn(u, **model_kwargs["potential_fn_kwargs"]), rng_key)

    def sample(self, state, model_args, model_kwargs):
        u, score, rng_key = state
        rng_key, key_accept, key_translation = random.split(rng_key, 3)
        u_proposal = self.translation_fn(u, key_translation, self.step_size, **model_kwargs["translation_fn_kwargs"])
        score_new = self.potential_fn(u_proposal, **model_kwargs["potential_fn_kwargs"])
        # acceptance ratio exp(-beta * score_new) / exp(-beta * score)
        accept_prob = jnp.exp(self.beta * (score - score_new))
        alpha = dist.Uniform().sample(key_accept)
        u_new = jnp.where(alpha < accept_prob, u_proposal, u)
        score_new = jnp.where(alpha < accept_prob, score_new, score)
        return MHState(u_new, score_new, rng_key)
constraints = None  # declared as a global variable; used inside potential_fn to calc the score
def potential_fn(params, verbose=False, print_func=None):
    # do something to calc the score
    return score
# def translation_fn(params, key, step_size):
#     # do something to translate params
#     # the translation is symmetric, so the simple form of the
#     # Metropolis-Hastings acceptance probability can be used
#     return params_new
kernel = NUTS(potential_fn=potential_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
# trainable_params: the params to be optimized
# constraints: the constraints on the params, used to calc the score in `potential_fn`
# verbose: whether to print debug info in `potential_fn`
mcmc.run(random.PRNGKey(0), init_params=trainable_params, extra_fields=('potential_energy',))
scores = mcmc.get_extra_fields()['potential_energy']
idx = jnp.argmin(scores)
result = mcmc.get_samples()['params'][idx]
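Since NUTS is given only potential_fn, the extra arguments (verbose, print_func, and the global constraints) presumably have to be fixed before sampling starts. A minimal sketch of one way to do that, assuming they stay constant for the whole run, using functools.partial:

from functools import partial

# Hypothetical helper step (not from the original post): bind the fixed keyword
# arguments so the kernel only ever calls potential_fn(params).
bound_potential_fn = partial(potential_fn, verbose=False, print_func=None)

kernel = NUTS(potential_fn=bound_potential_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
mcmc.run(random.PRNGKey(0), init_params=trainable_params, extra_fields=('potential_energy',))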
Hi @DQSSSSS, could you clarify your question? Is the code not working, or is it something else?
Hi @fehiepsi, sorry for the ambiguity. This code works, but it uses the simple MH algorithm. I want to use some of the amazing algorithms in NumPyro such as NUTS, HMCECS, etc., and select the best one.
Because your translation function requires a random key, I think you need to wrap the NUTS kernel for such logic. You can look at the HMCGibbs implementation for a design reference. We have a pending issue #898 for composing kernels.
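To illustrate the idea, here is a rough sketch of such a wrapper kernel. This is not the actual HMCGibbs code; it assumes the position is a single array, that potential_fn takes only the parameters, and that HMCState exposes z, potential_energy, z_grad, and rng_key fields as in the current NumPyro source:

import jax
import jax.numpy as jnp
from numpyro.infer import NUTS
from numpyro.infer.mcmc import MCMCKernel

class TranslationThenNUTS(MCMCKernel):
    """Rough sketch: alternate one translation MH move with one NUTS step."""

    sample_field = "z"  # delegate to the inner HMCState's position field

    def __init__(self, potential_fn, translation_fn, step_size=0.1):
        self._potential_fn = potential_fn
        self._translation_fn = translation_fn
        self._step_size = step_size
        self._inner = NUTS(potential_fn=potential_fn)

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        # Let the inner NUTS kernel build its own HMCState.
        return self._inner.init(rng_key, num_warmup, init_params, model_args, model_kwargs)

    def sample(self, state, model_args, model_kwargs):
        rng_key, key_translation, key_accept = jax.random.split(state.rng_key, 3)

        # Symmetric translation proposal + Metropolis accept/reject on the position.
        z_prop = self._translation_fn(state.z, key_translation, self._step_size)
        pe_prop = self._potential_fn(z_prop)
        accept = jax.random.uniform(key_accept) < jnp.exp(state.potential_energy - pe_prop)
        z_new = jnp.where(accept, z_prop, state.z)

        # Refresh the cached potential energy and gradient so the NUTS step starts
        # from a consistent state (this bookkeeping is what HMCGibbs handles internally).
        pe_new, z_grad_new = jax.value_and_grad(self._potential_fn)(z_new)
        state = state._replace(z=z_new, potential_energy=pe_new, z_grad=z_grad_new, rng_key=rng_key)

        return self._inner.sample(state, model_args, model_kwargs)

The real HMCGibbs implementation is more careful about warmup adaptation and about keeping the cached quantities consistent, so treat this only as a starting point.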
Closed. Please feel free to follow up on the discussion on the forum: https://forum.pyro.ai/
Hi all, thanks for your powerful library.
Sorry for the beginner question. I have implemented a Metropolis-Hastings algorithm following this link, because I need to customize the score calculation. My code is shown in the snippet above.
It works, but this algorithm performs poorly. I see that NumPyro has many amazing sampling algorithms (https://num.pyro.ai/en/stable/mcmc.html). I want to replace this simple MH algorithm with one of them; how can I do that?
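For reference, a minimal sketch of swapping in another kernel: samplers that accept a potential_fn (HMC, NUTS, BarkerMH, SA) can reuse the setup above as long as the score is a suitable function of the parameters, while HMCECS needs a full NumPyro model with subsampling and does not fit the potential_fn route. The custom translation move is dropped here, and BarkerMH is only an example choice:

from numpyro.infer import MCMC, BarkerMH

kernel = BarkerMH(potential_fn=potential_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
mcmc.run(random.PRNGKey(0), init_params=trainable_params, extra_fields=('potential_energy',))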