Customize MCMC algorithm #1632

Closed
DQSSSSS opened this issue Aug 24, 2023 · 5 comments
Labels
question Further information is requested

@DQSSSSS

DQSSSSS commented Aug 24, 2023

Hi all, thanks for your powerful library.
Sorry for the beginner question. I have implemented a Metropolis-Hastings algorithm following this link, since I need to customize the score calculation.
My code:

from collections import namedtuple
import os, time
import jax
import copy
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, BarkerMH

MHState = namedtuple("MHState", ["u", "score", "rng_key"])

class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
    sample_field = "u"

    def __init__(self, potential_fn, translation_fn, beta=100, step_size=0.1):
        self.potential_fn = potential_fn
        self.translation_fn = translation_fn
        self.step_size = step_size
        self.beta = beta

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        u = init_params
        return MHState(u, self.potential_fn(u, **model_kwargs["potential_fn_kwargs"]), rng_key)

    def sample(self, state, model_args, model_kwargs):
        u, score, rng_key = state
        rng_key, key_accept, key_translation = random.split(rng_key, 3)
        u_proposal = self.translation_fn(u, key_translation, self.step_size, **model_kwargs["translation_fn_kwargs"])
        score_new = self.potential_fn(u_proposal, **model_kwargs["potential_fn_kwargs"])
        accept_prob = jnp.exp(self.beta*(score - score_new)) # exp(-beta*s_next)/exp(-beta*s_now)
        alpha = dist.Uniform().sample(key_accept)
        u_new  = jnp.where(alpha < accept_prob, u_proposal, u)
        score_new  = jnp.where(alpha < accept_prob, score_new, score)
        return MHState(u_new, score_new, rng_key)

def potential_fn(params, constraints, verbose, print_func=None):
    # do something to calc the score
    return score

def translation_fn(params, key, step_size):
    # do something to translate params
    # it is a symmetric translation, so I can use the simple formulation of accept probability in Metropolis-Hastings algorithm
    return params_new

kernel = MetropolisHastings(potential_fn=potential_fn, translation_fn=translation_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
# trainable_params: the params to be optimized
# constraints: the constraints on the params, used to compute the score in `potential_fn`
# verbose: whether to print debug info in `potential_fn`
mcmc.run(random.PRNGKey(0), init_params=trainable_params, extra_fields=('score',),
        potential_fn_kwargs=dict(constraints=constraints, verbose=False),  # keyword must match the `constraints` parameter of potential_fn
        translation_fn_kwargs=dict(),
)

It works, but the algorithm performs poorly. NumPyro provides many other MCMC algorithms (https://num.pyro.ai/en/stable/mcmc.html), and I would like to replace this simple MH algorithm with one of them. How can I do that?
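
For reference, the acceptance rule in `sample` can be read as the general Metropolis-Hastings ratio specialised to a symmetric proposal. This is a short derivation under the assumption (taken from the inline comment) that the target density is proportional to $e^{-\beta s(u)}$, where $s$ is the score returned by `potential_fn` and $q$ is the proposal density implied by `translation_fn`:

```latex
\alpha(u \to u')
  = \min\!\left(1,\; \frac{\pi(u')\, q(u \mid u')}{\pi(u)\, q(u' \mid u)}\right)
  = \min\!\left(1,\; \frac{e^{-\beta s(u')}}{e^{-\beta s(u)}}\right)
  = \min\!\left(1,\; e^{\beta\,(s(u) - s(u'))}\right),
\qquad \text{since } q(u \mid u') = q(u' \mid u).
```

The code omits the $\min$, which does not change the outcome: a uniform draw in $[0, 1)$ is always below any value of at least 1, so such proposals are accepted either way.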

@DQSSSSS
Author

DQSSSSS commented Aug 24, 2023

I found that the following code might work, but I don't know how to use my translation_fn with it:

from collections import namedtuple
import os, time
import jax
import copy
from jax import random
import jax.numpy as jnp
import numpyro
import numpyro.distributions as dist
from numpyro.infer import MCMC, BarkerMH, NUTS

MHState = namedtuple("MHState", ["u", "score", "rng_key"])

class MetropolisHastings(numpyro.infer.mcmc.MCMCKernel):
    sample_field = "u"

    def __init__(self, potential_fn, translation_fn, beta=100, step_size=0.1):
        self.potential_fn = potential_fn
        self.translation_fn = translation_fn
        self.step_size = step_size
        self.beta = beta

    def init(self, rng_key, num_warmup, init_params, model_args, model_kwargs):
        u = init_params
        return MHState(u, self.potential_fn(u, **model_kwargs["potential_fn_kwargs"]), rng_key)

    def sample(self, state, model_args, model_kwargs):
        u, score, rng_key = state
        rng_key, key_accept, key_translation = random.split(rng_key, 3)
        u_proposal = self.translation_fn(u, key_translation, self.step_size, **model_kwargs["translation_fn_kwargs"])
        score_new = self.potential_fn(u_proposal, **model_kwargs["potential_fn_kwargs"])
        accept_prob = jnp.exp(self.beta*(score - score_new)) # exp(-beta*s_next)/exp(-beta*s_now)
        alpha = dist.Uniform().sample(key_accept)
        u_new  = jnp.where(alpha < accept_prob, u_proposal, u)
        score_new  = jnp.where(alpha < accept_prob, score_new, score)
        return MHState(u_new, score_new, rng_key)

constraints = None  # declared as a global variable

def potential_fn(params, verbose=False, print_func=None):
    # do something to calc the score
    return score

# def translation_fn(params, key, step_size):
#     # do something to translate params
#     # it is a symmetric translation, so I can use the simple formulation of accept probability in Metropolis-Hastings algorithm
#     return params_new

kernel = NUTS(potential_fn=potential_fn)
mcmc = MCMC(kernel, num_warmup=1000, num_samples=10000, num_chains=1, progress_bar=True)
# trainable_params: the params to be optimized
# constraints: the constraints on the params, used to compute the score in `potential_fn`
# verbose: whether to print the debug info in function `potential_fn`
mcmc.run(random.PRNGKey(0), init_params=trainable_params, extra_fields=('potential_energy',))
scores = mcmc.get_extra_fields()['potential_energy']
idx = jnp.argmin(scores)
result = mcmc.get_samples()['params'][idx]
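
One difference between the two versions may be worth noting. The hand-written kernel targets a density proportional to exp(-beta * score), while a kernel built with `potential_fn` treats the returned value as the potential energy itself, so `beta` no longer enters. A minimal sketch of restoring the scaling follows; `make_tempered_potential` and `toy_score` are hypothetical names used only for illustration, and the real score function would also need to be differentiable by JAX for gradient-based kernels such as NUTS.

```python
def make_tempered_potential(score_fn, beta=100.0):
    """Hypothetical helper: return a potential equal to beta * score."""

    def tempered_potential(params):
        # exp(-tempered_potential) == exp(-beta * score), as in the MH kernel
        return beta * score_fn(params)

    return tempered_potential


# Toy stand-in for the real score function (an assumption for illustration).
def toy_score(params):
    return sum(p * p for p in params)


potential = make_tempered_potential(toy_score, beta=100.0)
value = potential([0.5, 0.5])  # 100 * (0.25 + 0.25)
```

Under these assumptions, the wrapped function could be passed as `NUTS(potential_fn=potential)` in place of the unscaled score.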

fehiepsi added the question label on Aug 28, 2023
@fehiepsi
Member

Hi @DQSSSSS, could you clarify your question? Is the code not working or something?

@DQSSSSS
Author

DQSSSSS commented Aug 28, 2023

Hi @fehiepsi, sorry for the ambiguity. The code works, but it uses the simple MH algorithm. I want to try the more advanced algorithms in NumPyro, such as NUTS and HMCECS, and select the best one.
I don't know how to define my translation function translation_fn and use NUTS at the same time. I tried my best, but only got as far as the code in my second comment.
I need my own translation function because, given the complexity of the problem, the default one makes my algorithm run slowly.

@fehiepsi
Member

Because your translation function requires a random key, I think you need to wrap the NUTS kernel to implement that logic. You can look at the HMCGibbs implementation for a design. We have a pending issue, #898, for composing kernels.

@fehiepsi
Member

Closed. Please feel free to continue the discussion on the forum: https://forum.pyro.ai/
