From 743a657516bcb66354153a755a2ec123d2a5127a Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 22 Aug 2019 10:32:23 -0300 Subject: [PATCH 1/2] reduce number of logp evaluations II --- pymc3/smc/smc.py | 31 ++++++++++++++++++++++++------- pymc3/smc/smc_utils.py | 11 +++++++---- 2 files changed, 31 insertions(+), 11 deletions(-) diff --git a/pymc3/smc/smc.py b/pymc3/smc/smc.py index 95e7c6e9f9c..a79a6ec46fe 100644 --- a/pymc3/smc/smc.py +++ b/pymc3/smc/smc.py @@ -191,10 +191,14 @@ def sample_smc( if parallel and cores > 1: pool = mp.Pool(processes=cores) - results = pool.starmap(likelihood_logp, [(sample,) for sample in posterior]) + priors = pool.starmap(prior_logp, [(sample,) for sample in posterior]) + likelihoods = pool.starmap(likelihood_logp, [(sample,) for sample in posterior]) else: - results = [likelihood_logp(sample) for sample in posterior] - likelihoods = np.array(results).squeeze() + priors = [prior_logp(sample) for sample in posterior] + likelihoods = [likelihood_logp(sample) for sample in posterior] + + priors = np.array(priors).squeeze() + likelihoods = np.array(likelihoods).squeeze() while beta < 1: beta, old_beta, weights, sj = calc_beta(beta, likelihoods, threshold) @@ -203,6 +207,7 @@ def sample_smc( # resample based on plausibility weights (selection) resampling_indexes = np.random.choice(np.arange(draws), size=draws, p=weights) posterior = posterior[resampling_indexes] + priors = priors[resampling_indexes] likelihoods = likelihoods[resampling_indexes] # compute proposal distribution based on weights @@ -219,7 +224,6 @@ def sample_smc( pm._log.info("Stage: {:3d} Beta: {:.3f} Steps: {:3d}".format(stage, beta, n_steps)) # Apply Metropolis kernel (mutation) proposed = draws * n_steps - priors = np.array([prior_logp(sample) for sample in posterior]).squeeze() tempered_logp = priors + likelihoods * beta parameters = ( @@ -238,18 +242,31 @@ def sample_smc( results = pool.starmap( metrop_kernel, [ - (posterior[draw], tempered_logp[draw], likelihoods[draw], *parameters) + ( + posterior[draw], + tempered_logp[draw], + priors[draw], + likelihoods[draw], + *parameters, + ) for draw in range(draws) ], ) else: results = [ - metrop_kernel(posterior[draw], tempered_logp[draw], likelihoods[draw], *parameters) + metrop_kernel( + posterior[draw], + tempered_logp[draw], + priors[draw], + likelihoods[draw], + *parameters + ) for draw in tqdm(range(draws), disable=not progressbar) ] - posterior, acc_list, likelihoods = zip(*results) + posterior, acc_list, priors, likelihoods = zip(*results) posterior = np.array(posterior) + priors = np.array(priors) likelihoods = np.array(likelihoods) acc_rate = sum(acc_list) / proposed stage += 1 diff --git a/pymc3/smc/smc_utils.py b/pymc3/smc/smc_utils.py index 2a4f704da81..cf2f81b1948 100644 --- a/pymc3/smc/smc_utils.py +++ b/pymc3/smc/smc_utils.py @@ -110,6 +110,7 @@ def _posterior_to_trace(posterior, variables, model, var_info): def metrop_kernel( q_old, old_tempered_logp, + old_prior, old_likelihood, proposal, scaling, @@ -137,21 +138,23 @@ def metrop_kernel( q_new = (q_old + delta).astype("int64") else: delta[discrete] = np.round(delta[discrete], 0) - q_new = floatX(q_old + delta) + q_new = q_old + delta else: - q_new = floatX(q_old + delta) + q_new = q_old + delta ll = likelihood_logp(q_new) + pl = prior_logp(q_new) - new_tempered_logp = prior_logp(q_new) + ll * beta + new_tempered_logp = pl + ll * beta q_old, accept = metrop_select(new_tempered_logp - old_tempered_logp, q_new, q_old) if accept: accepted += 1 + old_prior = pl old_likelihood = ll old_tempered_logp = new_tempered_logp - return q_old, accepted, old_likelihood + return q_old, accepted, old_prior, old_likelihood def calc_beta(beta, likelihoods, threshold=0.5, psis=True): From 3ca66e8867988697ec1fb3ce780564f2d7e097c2 Mon Sep 17 00:00:00 2001 From: aloctavodia Date: Thu, 22 Aug 2019 11:21:00 -0300 Subject: [PATCH 2/2] floatX --- pymc3/smc/smc_utils.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pymc3/smc/smc_utils.py b/pymc3/smc/smc_utils.py index cf2f81b1948..e595fb92498 100644 --- a/pymc3/smc/smc_utils.py +++ b/pymc3/smc/smc_utils.py @@ -138,9 +138,9 @@ def metrop_kernel( q_new = (q_old + delta).astype("int64") else: delta[discrete] = np.round(delta[discrete], 0) - q_new = q_old + delta + q_new = floatX(q_old + delta) else: - q_new = q_old + delta + q_new = floatX(q_old + delta) ll = likelihood_logp(q_new) pl = prior_logp(q_new)