Skip to content

Commit

Permalink
Add more info to divergence warnings
Browse files Browse the repository at this point in the history
  • Loading branch information
aseyboldt committed Jul 1, 2020
1 parent 747db63 commit dbc2196
Show file tree
Hide file tree
Showing 4 changed files with 62 additions and 35 deletions.
50 changes: 30 additions & 20 deletions pymc3/backends/report.py
Original file line number Diff line number Diff line change
Expand Up @@ -12,10 +12,11 @@
# See the License for the specific language governing permissions and
# limitations under the License.

from collections import namedtuple
import logging
import enum
import typing
from typing import Any, Optional
import dataclasses

from ..util import is_transformed_name, get_untransformed_name

import arviz
Expand All @@ -38,9 +39,17 @@ class WarningType(enum.Enum):
BAD_ENERGY = 8


SamplerWarning = namedtuple(
'SamplerWarning',
"kind, message, level, step, exec_info, extra")
@dataclasses.dataclass
class SamplerWarning:
kind: WarningType
message: str
level: str
step: Optional[int] = None
exec_info: Optional[Any] = None
extra: Optional[Any] = None
divergence_point_source: Optional[dict] = None
divergence_point_dest: Optional[dict] = None
divergence_info: Optional[Any] = None


_LEVELS = {
Expand All @@ -53,7 +62,8 @@ class WarningType(enum.Enum):


class SamplerReport:
"""This object bundles warnings, convergence statistics and metadata of a sampling run."""
"""Bundle warnings, convergence stats and metadata of a sampling run."""

def __init__(self):
self._chain_warnings = {}
self._global_warnings = []
Expand All @@ -75,17 +85,17 @@ def ok(self):
for warn in self._warnings)

@property
def n_tune(self) -> typing.Optional[int]:
def n_tune(self) -> Optional[int]:
"""Number of tune iterations - not necessarily kept in trace!"""
return self._n_tune

@property
def n_draws(self) -> typing.Optional[int]:
def n_draws(self) -> Optional[int]:
"""Number of draw iterations."""
return self._n_draws

@property
def t_sampling(self) -> typing.Optional[float]:
def t_sampling(self) -> Optional[float]:
"""
Number of seconds that the sampling procedure took.
Expand All @@ -99,12 +109,11 @@ def raise_ok(self, level='error'):
if errors:
raise ValueError('Serious convergence issues during sampling.')

def _run_convergence_checks(self, idata:arviz.InferenceData, model):
def _run_convergence_checks(self, idata: arviz.InferenceData, model):
if idata.posterior.sizes['chain'] == 1:
msg = ("Only one chain was sampled, this makes it impossible to "
"run some convergence checks")
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info',
None, None, None)
warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info')
self._add_warnings([warn])
return

Expand All @@ -127,41 +136,42 @@ def _run_convergence_checks(self, idata:arviz.InferenceData, model):
msg = ("The rhat statistic is larger than 1.4 for some "
"parameters. The sampler did not converge.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'error', None, None, rhat)
WarningType.CONVERGENCE, msg, 'error', extra=rhat)
warnings.append(warn)
elif rhat_max > 1.2:
msg = ("The rhat statistic is larger than 1.2 for some "
"parameters.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'warn', None, None, rhat)
WarningType.CONVERGENCE, msg, 'warn', extra=rhat)
warnings.append(warn)
elif rhat_max > 1.05:
msg = ("The rhat statistic is larger than 1.05 for some "
"parameters. This indicates slight problems during "
"sampling.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'info', None, None, rhat)
WarningType.CONVERGENCE, msg, 'info', extra=rhat)
warnings.append(warn)

eff_min = min(val.min() for val in ess.values())
n_samples = idata.posterior.sizes['chain'] * idata.posterior.sizes['draw']
sizes = idata.posterior.sizes
n_samples = sizes['chain'] * sizes['draw']
if eff_min < 200 and n_samples >= 500:
msg = ("The estimated number of effective samples is smaller than "
"200 for some parameters.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'error', None, None, ess)
WarningType.CONVERGENCE, msg, 'error', extra=ess)
warnings.append(warn)
elif eff_min / n_samples < 0.1:
msg = ("The number of effective samples is smaller than "
"10% for some parameters.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'warn', None, None, ess)
WarningType.CONVERGENCE, msg, 'warn', extra=ess)
warnings.append(warn)
elif eff_min / n_samples < 0.25:
msg = ("The number of effective samples is smaller than "
"25% for some parameters.")
warn = SamplerWarning(
WarningType.CONVERGENCE, msg, 'info', None, None, ess)
WarningType.CONVERGENCE, msg, 'info', extra=ess)
warnings.append(warn)

self._add_warnings(warnings)
Expand Down Expand Up @@ -194,7 +204,7 @@ def filter_warns(warnings):
filtered.append(warn)
elif (start <= warn.step < stop and
(warn.step - start) % step == 0):
warn = warn._replace(step=warn.step - start)
warn = dataclasses.replace(warn, step=warn.step - start)
filtered.append(warn)
return filtered

Expand Down
40 changes: 28 additions & 12 deletions pymc3/step_methods/hmc/base_hmc.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,10 +28,16 @@

logger = logging.getLogger("pymc3")

HMCStepData = namedtuple("HMCStepData", "end, accept_stat, divergence_info, stats")
HMCStepData = namedtuple(
"HMCStepData",
"end, accept_stat, divergence_info, stats"
)

DivergenceInfo = namedtuple(
"DivergenceInfo",
"message, exec_info, state, state_div"
)

DivergenceInfo = namedtuple("DivergenceInfo", "message, exec_info, state")

class BaseHMC(arraystep.GradientSharedStep):
"""Superclass to implement Hamiltonian/hybrid monte carlo."""
Expand Down Expand Up @@ -151,8 +157,6 @@ def astep(self, q0):
message_energy,
"critical",
self.iter_count,
None,
None,
)
self._warnings.append(warning)
raise SamplingError("Bad initial energy")
Expand All @@ -170,19 +174,30 @@ def astep(self, q0):
self.potential.update(hmc_step.end.q, hmc_step.end.q_grad, self.tune)
if hmc_step.divergence_info:
info = hmc_step.divergence_info
point = None
point_dest = None
info_store = None
if self.tune:
kind = WarningType.TUNING_DIVERGENCE
point = None
else:
kind = WarningType.DIVERGENCE
self._num_divs_sample += 1
# We don't want to fill up all memory with divergence info
if self._num_divs_sample < 100:
point = self._logp_dlogp_func.array_to_dict(info.state.q)
else:
point = None
point_dest = self._logp_dlogp_func.array_to_dict(
info.state_div.q
)
info_store = info
warning = SamplerWarning(
kind, info.message, "debug", self.iter_count, info.exec_info, point
kind,
info.message,
"debug",
self.iter_count,
info.exec_info,
divergence_point_source=point,
divergence_point_dest=point_dest,
divergence_info=info_store,
)

self._warnings.append(warning)
Expand All @@ -191,7 +206,10 @@ def astep(self, q0):
if not self.tune:
self._samples_after_tune += 1

stats = {"tune": self.tune, "diverging": bool(hmc_step.divergence_info)}
stats = {
"tune": self.tune,
"diverging": bool(hmc_step.divergence_info)
}

stats.update(hmc_step.stats)
stats.update(self.step_adapt.stats())
Expand Down Expand Up @@ -230,9 +248,7 @@ def warnings(self):
)

if message:
warning = SamplerWarning(
WarningType.DIVERGENCES, message, "error", None, None, None
)
warning = SamplerWarning(WarningType.DIVERGENCES, message, "error")
warnings.append(warning)

warnings.extend(self.step_adapt.warnings())
Expand Down
5 changes: 3 additions & 2 deletions pymc3/step_methods/hmc/nuts.py
Original file line number Diff line number Diff line change
Expand Up @@ -200,7 +200,7 @@ def warnings(self):
"The chain reached the maximum tree depth. Increase "
"max_treedepth, increase target_accept or reparameterize."
)
warn = SamplerWarning(WarningType.TREEDEPTH, msg, "warn", None, None, None)
warn = SamplerWarning(WarningType.TREEDEPTH, msg, 'warn')
warnings.append(warn)
return warnings

Expand Down Expand Up @@ -321,6 +321,7 @@ def _single_step(self, left, epsilon):
except IntegrationError as err:
error_msg = str(err)
error = err
right = None
else:
# h - H0
energy_change = right.energy - self.start_energy
Expand Down Expand Up @@ -353,7 +354,7 @@ def _single_step(self, left, epsilon):
)
error = None
tree = Subtree(None, None, None, None, -np.inf, -np.inf, 1)
divergance_info = DivergenceInfo(error_msg, error, left)
divergance_info = DivergenceInfo(error_msg, error, left, right)
return tree, divergance_info, False

def _build_subtree(self, left, depth, epsilon):
Expand Down
2 changes: 1 addition & 1 deletion pymc3/step_methods/step_sizes.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ def warnings(self):
% (mean_accept, target_accept))
info = {'target': target_accept, 'actual': mean_accept}
warning = SamplerWarning(
WarningType.BAD_ACCEPTANCE, msg, 'warn', None, None, info)
WarningType.BAD_ACCEPTANCE, msg, 'warn', extra=info)
return [warning]
else:
return []

0 comments on commit dbc2196

Please sign in to comment.