From 4f8ad5d2772c8fd98b403d7c1d141e8f8394ba84 Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 22 Sep 2021 02:53:11 +0200 Subject: [PATCH] Improve tuning by skipping the first samples + add new experimental tuning method (#5004) * Fix issue in hmc gradient storage * Skip first samples during NUTS adaptation * Add test and doc for jitter+adapt_diag_grad * Improve tests of init methods * Add new tuning method to release notes * Remove old gradient mass matrix adaptation * Remove weight argument in quadpotential add_sample --- RELEASE-NOTES.md | 2 + pymc3/sampling.py | 62 +++---- pymc3/step_methods/hmc/nuts.py | 6 +- pymc3/step_methods/hmc/quadpotential.py | 211 ++++++++++++++++++------ pymc3/tests/test_quadpotential.py | 10 +- pymc3/tests/test_sampling.py | 13 +- 6 files changed, 200 insertions(+), 104 deletions(-) diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md index cdf96b03796..e1f8705b033 100644 --- a/RELEASE-NOTES.md +++ b/RELEASE-NOTES.md @@ -23,6 +23,8 @@ - The `OrderedMultinomial` distribution has been added for use on ordinal data which are _aggregated_ by trial, like multinomial observations, whereas `OrderedLogistic` only accepts ordinal data in a _disaggregated_ format, like categorical observations (see [#4773](https://github.com/pymc-devs/pymc3/pull/4773)). - The `Polya-Gamma` distribution has been added (see [#4531](https://github.com/pymc-devs/pymc3/pull/4531)). To make use of this distribution, the [`polyagamma>=1.3.1`](https://pypi.org/project/polyagamma/) library must be installed and available in the user's environment. +- A small change to the mass matrix tuning methods jitter+adapt_diag (the default) and adapt_diag improves performance early on during tuning for some models. [#5004](https://github.com/pymc-devs/pymc3/pull/5004) +- New experimental mass matrix tuning method jitter+adapt_diag_grad. [#5004](https://github.com/pymc-devs/pymc3/pull/5004) - ... ### Maintenance diff --git a/pymc3/sampling.py b/pymc3/sampling.py index e8131ce642f..665071f6639 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -287,25 +287,7 @@ def sample( by default. See ``discard_tuned_samples``. init : str Initialization method to use for auto-assigned NUTS samplers. - - * auto: Choose a default initialization method automatically. - Currently, this is ``jitter+adapt_diag``, but this can change in the future. - If you depend on the exact behaviour, choose an initialization method explicitly. - * adapt_diag: Start with a identity mass matrix and then adapt a diagonal based on the - variance of the tuning samples. All chains use the test value (usually the prior mean) - as starting point. - * jitter+adapt_diag: Same as ``adapt_diag``, but add uniform jitter in [-1, 1] to the - starting point in each chain. - * advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the - sample variance of the tuning samples. - * advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based - on the variance of the gradients during tuning. This is **experimental** and might be - removed in a future release. - * advi: Run ADVI to estimate posterior mean and diagonal mass matrix. - * advi_map: Initialize ADVI with MAP and use MAP as starting point. - * map: Use the MAP as starting point. This is discouraged. - * adapt_full: Adapt a dense mass matrix using the sample covariances - + See `pm.init_nuts` for a list of all options. step : function or iterable of functions A step function or collection of functions. If there are variables without step methods, step methods for those variables will be assigned automatically. By default the NUTS step @@ -516,6 +498,7 @@ def sample( random_seed=random_seed, progressbar=progressbar, jitter_max_retries=jitter_max_retries, + tune=tune, **kwargs, ) if start is None: @@ -2078,6 +2061,7 @@ def init_nuts( random_seed=None, progressbar=True, jitter_max_retries=10, + tune=None, **kwargs, ): """Set up the mass matrix initialization for NUTS. @@ -2099,11 +2083,11 @@ def init_nuts( as starting point. * jitter+adapt_diag: Same as ``adapt_diag``, but use test value plus a uniform jitter in [-1, 1] as starting point in each chain. + * jitter+adapt_diag_grad: + An experimental initialization method that uses information from gradients and samples + during tuning. * advi+adapt_diag: Run ADVI and then adapt the resulting diagonal mass matrix based on the sample variance of the tuning samples. - * advi+adapt_diag_grad: Run ADVI and then adapt the resulting diagonal mass matrix based - on the variance of the gradients during tuning. This is **experimental** and might be - removed in a future release. * advi: Run ADVI to estimate posterior mean and diagonal mass matrix. * advi_map: Initialize ADVI with MAP and use MAP as starting point. * map: Use the MAP as starting point. This is discouraged. @@ -2174,24 +2158,24 @@ def init_nuts( var = np.ones_like(mean) n = len(var) potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10) - elif init == "advi+adapt_diag_grad": - approx: pm.MeanField = pm.fit( - random_seed=random_seed, - n=n_init, - method="advi", - model=model, - callbacks=cb, - progressbar=progressbar, - obj_optimizer=pm.adagrad_window, + elif init == "jitter+adapt_diag_grad": + start = _init_jitter(model, model.initial_point, chains, jitter_max_retries) + mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0) + var = np.ones_like(mean) + n = len(var) + + if tune is not None and tune > 250: + stop_adaptation = tune - 50 + else: + stop_adaptation = None + + potential = quadpotential.QuadPotentialDiagAdaptExp( + n, + mean, + alpha=0.02, + use_grads=True, + stop_adaptation=stop_adaptation, ) - start = approx.sample(draws=chains) - start = list(start) - std_apoint = approx.std.eval() - cov = std_apoint ** 2 - mean = approx.mean.get_value() - weight = 50 - n = len(cov) - potential = quadpotential.QuadPotentialDiagAdaptGrad(n, mean, cov, weight) elif init == "advi+adapt_diag": approx = pm.fit( random_seed=random_seed, diff --git a/pymc3/step_methods/hmc/nuts.py b/pymc3/step_methods/hmc/nuts.py index 267c20659fa..ef1375e9986 100644 --- a/pymc3/step_methods/hmc/nuts.py +++ b/pymc3/step_methods/hmc/nuts.py @@ -253,9 +253,7 @@ def __init__(self, ndim, integrator, start, step_size, Emax): self.start_energy = np.array(start.energy) self.left = self.right = start - self.proposal = Proposal( - start.q.data, start.q_grad.data, start.energy, 1.0, start.model_logp - ) + self.proposal = Proposal(start.q.data, start.q_grad, start.energy, 1.0, start.model_logp) self.depth = 0 self.log_size = 0 self.log_weighted_accept_sum = -np.inf @@ -350,7 +348,7 @@ def _single_step(self, left, epsilon): log_size = -energy_change proposal = Proposal( right.q.data, - right.q_grad.data, + right.q_grad, right.energy, log_p_accept_weighted, right.model_logp, diff --git a/pymc3/step_methods/hmc/quadpotential.py b/pymc3/step_methods/hmc/quadpotential.py index 2bb2f196c50..21c24907dbc 100644 --- a/pymc3/step_methods/hmc/quadpotential.py +++ b/pymc3/step_methods/hmc/quadpotential.py @@ -154,8 +154,40 @@ def __init__( adaptation_window=101, adaptation_window_multiplier=1, dtype=None, + discard_window=50, + early_update=False, + store_mass_matrix_trace=False, ): - """Set up a diagonal mass matrix.""" + """Set up a diagonal mass matrix. + + Parameters + ---------- + n : int + The number of parameters. + initial_mean : np.ndarray + An initial guess for the posterior mean of each parameter. + initial_diag : np.ndarray + An estimate of the posterior variance of each parameter. + initial_weight : int + How much weight the initial guess has compared to new samples during tuning. + Measured in equivalent number of samples. + adaptation_window : int + The size of the adaptation window during tuning. It specifies how many samples + are used to estimate the mass matrix in each section of the adaptation. + adaptation_window_multiplier : float + The factor with which we increase the adaptation window after each adaptation + window. + dtype : np.dtype + The dtype used to store the mass matrix + discard_window : int + The number of initial samples that are just discarded and not used to estimate + the mass matrix. + early_update : bool + Whether to update the mass matrix live during the first adaptation window. + store_mass_matrix_trace : bool + If true, store the mass matrix at each step of the adaptation. Only for debugging + purposes. + """ if initial_diag is not None and initial_diag.ndim != 1: raise ValueError("Initial diagonal must be one-dimensional.") if initial_mean.ndim != 1: @@ -175,12 +207,18 @@ def __init__( self.dtype = dtype self._n = n + self._discard_window = discard_window + self._early_update = early_update + self._initial_mean = initial_mean self._initial_diag = initial_diag self._initial_weight = initial_weight self.adaptation_window = adaptation_window self.adaptation_window_multiplier = float(adaptation_window_multiplier) + self._store_mass_matrix_trace = store_mass_matrix_trace + self._mass_trace = [] + self.reset() def reset(self): @@ -222,12 +260,18 @@ def _update_from_weightvar(self, weightvar): def update(self, sample, grad, tune): """Inform the potential about a new sample during tuning.""" + if self._store_mass_matrix_trace: + self._mass_trace.append(self._stds.copy()) + if not tune: return - self._foreground_var.add_sample(sample, weight=1) - self._background_var.add_sample(sample, weight=1) - self._update_from_weightvar(self._foreground_var) + if self._n_samples > self._discard_window: + self._foreground_var.add_sample(sample) + self._background_var.add_sample(sample) + + if self._early_update or self._n_samples > self.adaptation_window: + self._update_from_weightvar(self._foreground_var) if self._n_samples > 0 and self._n_samples % self.adaptation_window == 0: self._foreground_var = self._background_var @@ -275,47 +319,6 @@ def raise_ok(self, map_info): raise ValueError("\n".join(errmsg)) -class QuadPotentialDiagAdaptGrad(QuadPotentialDiagAdapt): - """Adapt a diagonal mass matrix from the variances of the gradients. - - This is experimental, and may be removed without prior deprication. - """ - - def reset(self): - super().reset() - self._grads1 = np.zeros(self._n, dtype=self.dtype) - self._ngrads1 = 0 - self._grads2 = np.zeros(self._n, dtype=self.dtype) - self._ngrads2 = 0 - - def _update(self, var): - self._var[:] = var - np.sqrt(self._var, out=self._stds) - np.divide(1, self._stds, out=self._inv_stds) - self._var_aesara.set_value(self._var) - - def update(self, sample, grad, tune): - """Inform the potential about a new sample during tuning.""" - if not tune: - return - - self._grads1[:] += np.abs(grad) - self._grads2[:] += np.abs(grad) - self._ngrads1 += 1 - self._ngrads2 += 1 - - if self._n_samples <= 150: - super().update(sample, grad, tune) - else: - self._update((self._ngrads1 / self._grads1) ** 2) - - if self._n_samples > 100 and self._n_samples % 100 == 50: - self._ngrads1 = self._ngrads2 - self._ngrads2 = 1 - self._grads1[:] = self._grads2 - self._grads2[:] = 1 - - class _WeightedVariance: """Online algorithm for computing mean of variance.""" @@ -340,13 +343,13 @@ def __init__( if self.mean.shape != (nelem,): raise ValueError("Invalid shape for initial mean.") - def add_sample(self, x, weight): + def add_sample(self, x): x = np.asarray(x) self.n_samples += 1 old_diff = x - self.mean self.mean[:] += old_diff / self.n_samples new_diff = x - self.mean - self.raw_var[:] += weight * old_diff * new_diff + self.raw_var[:] += old_diff * new_diff def current_variance(self, out=None): if self.n_samples == 0: @@ -360,6 +363,112 @@ def current_mean(self): return self.mean.copy(dtype=self._dtype) +class _ExpWeightedVariance: + def __init__(self, n_vars, *, init_mean, init_var, alpha): + self._variance = init_var + self._mean = init_mean + self._alpha = alpha + + def add_sample(self, value): + alpha = self._alpha + delta = value - self._mean + self._mean[...] += alpha * delta + self._variance[...] = (1 - alpha) * (self._variance + alpha * delta ** 2) + + def current_variance(self, out=None): + if out is None: + out = np.empty_like(self._variance) + np.copyto(out, self._variance) + return out + + def current_mean(self, out=None): + if out is None: + out = np.empty_like(self._mean) + np.copyto(out, self._mean) + return out + + +class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt): + def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, **kwargs): + """Set up a diagonal mass matrix. + + Parameters + ---------- + n : int + The number of parameters. + initial_mean : np.ndarray + An initial guess for the posterior mean of each parameter. + initial_diag : np.ndarray + An estimate of the posterior variance of each parameter. + alpha : float + Decay rate of the exponetial weighted variance. + use_grads : bool + Use gradients, not only samples to estimate the mass matrix. + stop_adaptation : int + Stop the mass matrix adaptation after this many samples. + dtype : np.dtype + The dtype used to store the mass matrix + discard_window : int + The number of initial samples that are just discarded and not used to estimate + the mass matrix. + store_mass_matrix_trace : bool + If true, store the mass matrix at each step of the adaptation. Only for debugging + purposes. + """ + if len(args) > 3: + raise ValueError("Unsupported arguments to QuadPotentialDiagAdaptExp") + + super().__init__(*args, **kwargs) + self._alpha = alpha + self._use_grads = use_grads + + if stop_adaptation is None: + stop_adaptation = np.inf + self._stop_adaptation = stop_adaptation + + def update(self, sample, grad, tune): + if tune and self._n_samples < self._stop_adaptation: + if self._n_samples > self._discard_window: + self._variance_estimator.add_sample(sample) + if self._use_grads: + self._variance_estimator_grad.add_sample(grad) + elif self._n_samples == self._discard_window: + self._variance_estimator = _ExpWeightedVariance( + self._n, + init_mean=sample.copy(), + init_var=np.zeros_like(sample), + alpha=self._alpha, + ) + if self._use_grads: + self._variance_estimator_grad = _ExpWeightedVariance( + self._n, + init_mean=grad.copy(), + init_var=np.zeros_like(grad), + alpha=self._alpha, + ) + + if self._n_samples > 2 * self._discard_window: + if self._use_grads: + self._update_from_variances( + self._variance_estimator, self._variance_estimator_grad + ) + else: + self._update_from_weightvar(self._variance_estimator) + + self._n_samples += 1 + + if self._store_mass_matrix_trace: + self._mass_trace.append(self._stds.copy()) + + def _update_from_variances(self, var_estimator, inv_var_estimator): + var = var_estimator.current_variance() + inv_var = inv_var_estimator.current_variance() + updated = np.sqrt(var / inv_var) + self._var[:] = updated + np.sqrt(updated, out=self._stds) + np.divide(1, self._stds, out=self._inv_stds) + + class QuadPotentialDiag(QuadPotential): """Quad potential using a diagonal covariance matrix.""" @@ -554,8 +663,8 @@ def update(self, sample, grad, tune): # Steps since previous update delta = self._n_samples - self._previous_update - self._foreground_cov.add_sample(sample, weight=1) - self._background_cov.add_sample(sample, weight=1) + self._foreground_cov.add_sample(sample) + self._background_cov.add_sample(sample) # Update the covariance matrix and recompute the Cholesky factorization # every "update_window" steps @@ -614,13 +723,13 @@ def __init__( if self.mean.shape != (nelem,): raise ValueError("Invalid shape for initial mean.") - def add_sample(self, x, weight): + def add_sample(self, x): x = np.asarray(x) self.n_samples += 1 old_diff = x - self.mean self.mean[:] += old_diff / self.n_samples new_diff = x - self.mean - self.raw_cov[:] += weight * new_diff[:, None] * old_diff[None, :] + self.raw_cov[:] += new_diff[:, None] * old_diff[None, :] def current_covariance(self, out=None): if self.n_samples == 0: diff --git a/pymc3/tests/test_quadpotential.py b/pymc3/tests/test_quadpotential.py index c869511a606..4bf342047b8 100644 --- a/pymc3/tests/test_quadpotential.py +++ b/pymc3/tests/test_quadpotential.py @@ -169,7 +169,7 @@ def test_weighted_covariance(ndim=10, seed=5432): est = quadpotential._WeightedCovariance(ndim) for sample in samples: - est.add_sample(sample, 1) + est.add_sample(sample) mu_est = est.current_mean() cov_est = est.current_covariance() @@ -184,7 +184,7 @@ def test_weighted_covariance(ndim=10, seed=5432): 10, ) for sample in samples[10:]: - est2.add_sample(sample, 1) + est2.add_sample(sample) mu_est2 = est2.current_mean() cov_est2 = est2.current_covariance() @@ -279,9 +279,3 @@ def test_full_adapt_sampling(seed=289586): pot = quadpotential.QuadPotentialFullAdapt(initial_point_size, np.zeros(initial_point_size)) step = pymc3.NUTS(model=model, potential=pot) pymc3.sample(draws=10, tune=1000, random_seed=seed, step=step, cores=1, chains=1) - - -def test_issue_3965(): - with pymc3.Model(): - pymc3.Normal("n") - pymc3.sample(100, tune=300, chains=1, init="advi+adapt_diag_grad") diff --git a/pymc3/tests/test_sampling.py b/pymc3/tests/test_sampling.py index 0b6ce99a15b..04767ae9b68 100644 --- a/pymc3/tests/test_sampling.py +++ b/pymc3/tests/test_sampling.py @@ -87,10 +87,19 @@ def test_sample(self): def test_sample_init(self): with self.model: - for init in ("advi", "advi_map", "map"): + for init in ( + "advi", + "advi_map", + "map", + "adapt_diag", + "jitter+adapt_diag", + "jitter+adapt_diag_grad", + "adapt_full", + "jitter+adapt_full", + ): pm.sample( init=init, - tune=0, + tune=120, n_init=1000, draws=50, random_seed=self.random_seed,