Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Improve tuning by skipping the first samples + add new experimental tuning method #5004

Merged
merged 12 commits into from
Sep 22, 2021
20 changes: 20 additions & 0 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -516,6 +516,7 @@ def sample(
random_seed=random_seed,
progressbar=progressbar,
jitter_max_retries=jitter_max_retries,
tune=tune,
**kwargs,
)
if start is None:
Expand Down Expand Up @@ -2078,6 +2079,7 @@ def init_nuts(
random_seed=None,
progressbar=True,
jitter_max_retries=10,
tune=None,
**kwargs,
):
"""Set up the mass matrix initialization for NUTS.
Expand Down Expand Up @@ -2174,6 +2176,24 @@ def init_nuts(
var = np.ones_like(mean)
n = len(var)
potential = quadpotential.QuadPotentialDiagAdapt(n, mean, var, 10)
elif init == "jitter+adapt_diag_grad":
start = _init_jitter(model, model.initial_point, chains, jitter_max_retries)
mean = np.mean([DictToArrayBijection.map(vals).data for vals in start], axis=0)
var = np.ones_like(mean)
n = len(var)

if tune is not None and tune > 200:
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved
aseyboldt marked this conversation as resolved.
Show resolved Hide resolved
stop_adaptation = tune - 50
else:
stop_adaptation = None

potential = quadpotential.QuadPotentialDiagAdaptExp(
n,
mean,
alpha=0.02,
use_grads=True,
stop_adaptation=stop_adaptation,
)
elif init == "advi+adapt_diag_grad":
approx: pm.MeanField = pm.fit(
random_seed=random_seed,
Expand Down
103 changes: 100 additions & 3 deletions pymc3/step_methods/hmc/quadpotential.py
Original file line number Diff line number Diff line change
Expand Up @@ -154,6 +154,10 @@ def __init__(
adaptation_window=101,
adaptation_window_multiplier=1,
dtype=None,
discard_window=50,
aseyboldt marked this conversation as resolved.
Show resolved Hide resolved
initial_weights=None,
early_update=False,
store_mass_matrix_trace=False,
):
"""Set up a diagonal mass matrix."""
if initial_diag is not None and initial_diag.ndim != 1:
Expand All @@ -175,12 +179,20 @@ def __init__(
self.dtype = dtype
self._n = n

self._discard_window = discard_window
self._early_update = early_update

self._initial_mean = initial_mean
self._initial_diag = initial_diag
self._initial_weight = initial_weight
self.adaptation_window = adaptation_window
self.adaptation_window_multiplier = float(adaptation_window_multiplier)

self._store_mass_matrix_trace = store_mass_matrix_trace
self._mass_trace = []

self._initial_weights = initial_weights

self.reset()

def reset(self):
Expand Down Expand Up @@ -222,12 +234,18 @@ def _update_from_weightvar(self, weightvar):

def update(self, sample, grad, tune):
"""Inform the potential about a new sample during tuning."""
if self._store_mass_matrix_trace:
self._mass_trace.append(self._stds.copy())
ricardoV94 marked this conversation as resolved.
Show resolved Hide resolved

if not tune:
return

self._foreground_var.add_sample(sample, weight=1)
self._background_var.add_sample(sample, weight=1)
self._update_from_weightvar(self._foreground_var)
if self._n_samples > self._discard_window:
self._foreground_var.add_sample(sample, weight=1)
self._background_var.add_sample(sample, weight=1)

if self._early_update or self._n_samples > self.adaptation_window:
self._update_from_weightvar(self._foreground_var)

if self._n_samples > 0 and self._n_samples % self.adaptation_window == 0:
self._foreground_var = self._background_var
Expand Down Expand Up @@ -342,6 +360,8 @@ def __init__(

def add_sample(self, x, weight):
x = np.asarray(x)
if weight != 1:
raise ValueError("weight is unused and broken")
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Suggested change
raise ValueError("weight is unused and broken")
raise ValueError("Setting weight != 1 is not supported.")

Or maybe we should just remove it all-together.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

done

self.n_samples += 1
old_diff = x - self.mean
self.mean[:] += old_diff / self.n_samples
Expand All @@ -360,6 +380,83 @@ def current_mean(self):
return self.mean.copy(dtype=self._dtype)


class _ExpWeightedVariance:
def __init__(self, n_vars, *, init_mean, init_var, alpha):
self._variance = init_var
self._mean = init_mean
self._alpha = alpha

def add_sample(self, value):
alpha = self._alpha
delta = value - self._mean
self._mean += alpha * delta
self._variance = (1 - alpha) * (self._variance + alpha * delta ** 2)

def current_variance(self, out=None):
if out is None:
out = np.empty_like(self._variance)
np.copyto(out, self._variance)
return out

def current_mean(self, out=None):
if out is None:
out = np.empty_like(self._mean)
np.copyto(out, self._mean)
return out


class QuadPotentialDiagAdaptExp(QuadPotentialDiagAdapt):
def __init__(self, *args, alpha, use_grads=False, stop_adaptation=None, **kwargs):
super().__init__(*args, **kwargs)
self._alpha = alpha
self._use_grads = use_grads
self._stop_adaptation = stop_adaptation

def update(self, sample, grad, tune):
if tune and self._n_samples < self._stop_adaptation:
if self._n_samples > self._discard_window:
self._variance_estimator.add_sample(sample)
if self._use_grads:
self._variance_estimator_grad.add_sample(grad)
elif self._n_samples == self._discard_window:
self._variance_estimator = _ExpWeightedVariance(
self._n,
init_mean=sample.copy(),
init_var=np.zeros_like(sample),
alpha=self._alpha,
)
if self._use_grads:
self._variance_estimator_grad = _ExpWeightedVariance(
self._n,
init_mean=grad.copy(),
init_var=np.zeros_like(grad),
alpha=self._alpha,
)

if self._n_samples > 2 * self._discard_window:
if self._use_grads:
self._update_from_variances(
self._variance_estimator, self._variance_estimator_grad
)
else:
self._update_from_weightvar(self._variance_estimator)

self._n_samples += 1

if self._store_mass_matrix_trace:
self._mass_trace.append(self._stds.copy())

def _update_from_variances(self, var_estimator, inv_var_estimator):
var = var_estimator.current_variance()
inv_var = inv_var_estimator.current_variance()
# print(inv_var)
updated = np.sqrt(var / inv_var)
self._var[:] = updated
# updated = np.exp((np.log(var) - np.log(inv_var)) / 2)
np.sqrt(updated, out=self._stds)
np.divide(1, self._stds, out=self._inv_stds)


class QuadPotentialDiag(QuadPotential):
"""Quad potential using a diagonal covariance matrix."""

Expand Down