
Improve tuning by skipping the first samples + add new experimental tuning method #5004

Merged · 12 commits · Sep 22, 2021
2 changes: 2 additions & 0 deletions RELEASE-NOTES.md
@@ -23,6 +23,8 @@
- The `OrderedMultinomial` distribution has been added for use on ordinal data which are _aggregated_ by trial, like multinomial observations, whereas `OrderedLogistic` only accepts ordinal data in a _disaggregated_ format, like categorical
observations (see [#4773](https://github.com/pymc-devs/pymc3/pull/4773)).
- The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc3/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment.
- A small change to the mass matrix tuning methods `jitter+adapt_diag` (the default) and `adapt_diag` improves performance early in tuning for some models. [#5004](https://github.com/pymc-devs/pymc3/pull/5004)
- New experimental mass matrix tuning method `jitter+adapt_diag_grad` (see the usage sketch after this diff). [#5004](https://github.com/pymc-devs/pymc3/pull/5004)
- ...

### Maintenance
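For example, the experimental method can be opted into through the `init` argument of `pm.sample`. A minimal sketch (the toy model is only illustrative):

```python
import pymc3 as pm

with pm.Model():
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    # Opt in to the experimental gradient-based mass matrix tuning;
    # the default init remains "jitter+adapt_diag".
    trace = pm.sample(tune=1000, chains=2, init="jitter+adapt_diag_grad")
```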
62 changes: 23 additions & 39 deletions pymc3/sampling.py
@@ -287,25 +287,7 @@ def sample(
by default. See ``discard_tuned_samples``.
init : str
Initialization method to use for auto-assigned NUTS samplers.

* auto: Choose a default initialization method automatically.
Currently, this is ``jitter+adapt_diag``, but this can change in the future.
If you depend on the exact behaviour, choose an initialization method explicitly.
* adapt_diag: Start with an identity mass matrix and then adapt a diagonal based on the
variance of the tuning samples. All chains use the test value (usually the prior mean)
as starting point.
* jitter+adapt_diag: Same as ``adapt_diag``, but add uniform jitter in [-1, 1] to the
starting point in each chain.
* advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
sample variance of the tuning samples.
* advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
on the variance of the gradients during tuning. This is **experimental** and might be
removed in a future release.
* advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
* map: Use the MAP as starting point. This is discouraged.
* adapt_full: Adapt a dense mass matrix using the sample covariances.

See `pm.init_nuts` for a list of all options.
step : function or iterable of functions
A step function or collection of functions. If there are variables without step methods,
step methods for those variables will be assigned automatically. By default the NUTS step
@@ -516,6 +498,7 @@ def sample(
random_seed=random_seed,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
**kwargs,
)
if start is None:
@@ -2078,6 +2061,7 @@ def init_nuts(
random_seed=None,
progressbar=True,
jitter_max_retries=10,
tune=None,
**kwargs,
):
"""Set up the mass matrix initialization for NUTS.
@@ -2099,11 +2083,11 @@
as starting point.
* jitter+adapt_diag: Same as ``adapt_diag``, but use test value plus a uniform jitter in
[-1, 1] as starting point in each chain.
* jitter+adapt_diag_grad:
An experimental initialization method that uses information from gradients and
samples during tuning (see the usage sketch after this list).
* advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the
sample variance of the tuning samples.
* advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based
on the variance of the gradients during tuning. This is **experimental** and might be
removed in a future release.
* advi: Run ADVI to estimate posterior mean and diagonal mass matrix.
* advi_map: Initialize ADVI with MAP and use MAP as starting point.
* map: Use the MAP as starting point. This is discouraged.
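Direct initialization with the new `tune` argument might look as follows; a sketch assuming, as in this module, that `pm.init_nuts` returns a list of start points and a NUTS step method:

```python
import pymc3 as pm

with pm.Model():
    x = pm.Normal("x", mu=0.0, sigma=1.0)
    # Passing tune lets the potential stop adapting 50 steps before
    # tuning ends (only when tune > 250).
    start, step = pm.init_nuts(
        init="jitter+adapt_diag_grad", chains=2, tune=2000
    )
    trace = pm.sample(tune=2000, chains=2, start=start, step=step)
```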
@@ -2174,24 +2158,24 @@
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
elif init == "advi+adapt_diag_grad":
approx: pm.MeanField = pm.fit(
random_seed=random_seed,
n=n_init,
method="advi",
model=model,
callbacks=cb,
progressbar=progressbar,
obj_optimizer=pm.adagrad_window,
elif init == "jitter+adapt_diag_grad":
start = _init_jitter(model, model.initial_point, chains, jitter_max_retries)
mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0)
var = np.ones_like(mean)
n = len(var)

if tune is not None and tune > 250:
stop_adaptation = tune - 50
else:
stop_adaptation = None

potential = quadpotential.QuadPotentialDiagAdaptExp(
n,
mean,
alpha=0.02,
use_grads=True,
stop_adaptation=stop_adaptation,
)
start = approx.sample(draws=chains)
start = list(start)
std_apoint = approx.std.eval()
cov = std_apoint ** 2
mean = approx.mean.get_value()
weight = 50
n = len(cov)
potential = quadpotential.QuadPotentialDiagAdaptGrad(n, mean, cov, weight)
elif init == "advi+adapt_diag":
approx = pm.fit(
random_seed=random_seed,
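The new path freezes adaptation 50 steps before tuning ends (when more than 250 tuning steps are available), so step size calibration finishes against a fixed mass matrix, and the exponentially weighted estimate quickly forgets the earliest, poorly adapted samples. A minimal sketch of that idea, with a hypothetical `ewma_diag_variance` helper standing in for the adaptation done by `QuadPotentialDiagAdaptExp`:

```python
import numpy as np

def ewma_diag_variance(samples, alpha=0.02, stop_adaptation=None):
    # Hypothetical stand-in for the exponentially weighted mean/variance
    # recursion: recent draws dominate, so the first (poorly adapted)
    # tuning samples fade out of the estimate quickly.
    mean = np.zeros(samples.shape[1])
    var = np.ones(samples.shape[1])
    for i, q in enumerate(samples):
        if stop_adaptation is not None and i >= stop_adaptation:
            break  # freeze the mass matrix for the remaining tuning steps
        delta = q - mean
        mean = mean + alpha * delta
        var = (1.0 - alpha) * (var + alpha * delta ** 2)
    return var

tune = 2000
stop = tune - 50 if tune is not None and tune > 250 else None
print(ewma_diag_variance(np.random.randn(tune, 3), stop_adaptation=stop))
```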
6 changes: 2 additions & 4 deletions pymc3/step_methods/hmc/nuts.py
@@ -253,9 +253,7 @@ def __init__(self, ndim, integrator, start, step_size, Emax):
self.start_energy = np.array(start.energy)

self.left = self.right = start
self.proposal = Proposal(
start.q.data, start.q_grad.data, start.energy, 1.0, start.model_logp
)
self.proposal = Proposal(start.q.data, start.q_grad, start.energy, 1.0, start.model_logp)
self.depth = 0
self.log_size = 0
self.log_weighted_accept_sum = -np.inf
@@ -350,7 +348,7 @@ def _single_step(self, left, epsilon):
log_size = -energy_change
proposal = Proposal(
right.q.data,
right.q_grad.data,
right.q_grad,
right.energy,
log_p_accept_weighted,
right.model_logp,