diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index f3dafb2bd03..9f4dbbe9e5e 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -13,6 +13,7 @@ - Add sampler stats `process_time_diff`, `perf_counter_diff` and `perf_counter_start`, that record wall and CPU times for each NUTS and HMC sample (see [ #3986](https://github.com/pymc-devs/pymc3/pull/3986)). - Extend `keep_size` argument handling for `sample_posterior_predictive` and `fast_sample_posterior_predictive`, to work on arviz InferenceData and xarray Dataset input values. (see [PR #4006](https://github.com/pymc-devs/pymc3/pull/4006) and [Issue #4004](https://github.com/pymc-devs/pymc3/issues/4004). - SMC-ABC: add the wasserstein and energy distance functions. Refactor API, the distance, sum_stats and epsilon arguments are now passed `pm.Simulator` instead of `pm.sample_smc`. Add random method to `pm.Simulator`. Add option to save the simulated data. Improves LaTeX representation [#3996](https://github.com/pymc-devs/pymc3/pull/3996) +- SMC-ABC: Allow use of potentials by adding them to the prior term. [#4016](https://github.com/pymc-devs/pymc3/pull/4016) ## PyMC3 3.9.2 (24 June 2020) ### Maintenance diff --git a/pymc3/smc/sample_smc.py b/pymc3/smc/sample_smc.py index caa366346e7..33b297c2269 100644 --- a/pymc3/smc/sample_smc.py +++ b/pymc3/smc/sample_smc.py @@ -137,6 +137,7 @@ def sample_smc( _log = logging.getLogger("pymc3") _log.info("Initializing SMC sampler...") + model = modelcontext(model) if cores is None: cores = _cpu_count() @@ -165,8 +166,10 @@ def sample_smc( if kernel.lower() == "abc": warnings.warn(EXPERIMENTAL_WARNING) - if len(modelcontext(model).observed_RVs) != 1: + if len(model.observed_RVs) != 1: warnings.warn("SMC-ABC only works properly with models with one observed variable") + if model.potentials: + _log.info("Potentials will be added to the prior term") params = ( draws, diff --git a/pymc3/smc/smc.py b/pymc3/smc/smc.py index 264ffb2779e..1e1e58c92d3 100644 --- a/pymc3/smc/smc.py +++ b/pymc3/smc/smc.py @@ -17,6 +17,7 @@ import numpy as np from scipy.special import logsumexp from theano import function as theano_function +import theano.tensor as tt from ..model import modelcontext, Point from ..theanof import floatX, inputvars, make_shared_replacements, join_nonshared_inputs @@ -100,9 +101,11 @@ def setup_kernel(self): Set up the likelihood logp function based on the chosen kernel """ shared = make_shared_replacements(self.variables, self.model) - self.prior_logp_func = logp_forw([self.model.varlogpt], self.variables, shared) if self.kernel.lower() == "abc": + factors = [var.logpt for var in self.model.free_RVs] + factors += [tt.sum(factor) for factor in self.model.potentials] + self.prior_logp_func = logp_forw([tt.sum(factors)], self.variables, shared) simulator = self.model.observed_RVs[0] distance = simulator.distribution.distance sum_stat = simulator.distribution.sum_stat @@ -120,6 +123,7 @@ def setup_kernel(self): self.save_sim_data, ) elif self.kernel.lower() == "metropolis": + self.prior_logp_func = logp_forw([self.model.varlogpt], self.variables, shared) self.likelihood_logp_func = logp_forw([self.model.datalogpt], self.variables, shared) def initialize_logp(self): diff --git a/pymc3/tests/test_smc.py b/pymc3/tests/test_smc.py index d88fcaf18dd..f2b682e5ca8 100644 --- a/pymc3/tests/test_smc.py +++ b/pymc3/tests/test_smc.py @@ -130,6 +130,14 @@ def abs_diff(eps, obs_data, sim_data): observed=self.data, ) + with pm.Model() as self.SMABC_potential: + a = pm.Normal("a", mu=0, sigma=1) + b = pm.HalfNormal("b", sigma=1) + c = pm.Potential("c", pm.math.switch(a > 0, 0, -np.inf)) + s = pm.Simulator( + "s", normal_sim, params=(a, b), sum_stat="sort", epsilon=1, observed=self.data + ) + def test_one_gaussian(self): with self.SMABC_test: trace = pm.sample_smc(draws=1000, kernel="ABC") @@ -157,6 +165,11 @@ def test_custom_dist_sum(self): with self.SMABC_test2: trace = pm.sample_smc(draws=1000, kernel="ABC") + def test_potential(self): + with self.SMABC_potential: + trace = pm.sample_smc(draws=1000, kernel="ABC") + assert np.all(trace["a"] >= 0) + def test_automatic_use_of_sort(self): with pm.Model() as model: s_g = pm.Simulator(