From 8560f1e16a96d2345b65892e77509145a95da5ea Mon Sep 17 00:00:00 2001 From: Adrian Seyboldt Date: Wed, 1 Jul 2020 15:06:01 +0200 Subject: [PATCH] Honor discard_tuned_samples during KeyboardInterrupt (#3785) * Honor discard_tuned_samples during KeyboardInterrupt * Do not compute convergence checks without samples --- pymc3/backends/report.py | 9 ++++++++- pymc3/sampling.py | 7 ++++++- 2 files changed, 14 insertions(+), 2 deletions(-) diff --git a/pymc3/backends/report.py b/pymc3/backends/report.py index 1e506a836da..4384b85cbf5 100644 --- a/pymc3/backends/report.py +++ b/pymc3/backends/report.py @@ -99,7 +99,14 @@ def raise_ok(self, level='error'): if errors: raise ValueError('Serious convergence issues during sampling.') - def _run_convergence_checks(self, idata:arviz.InferenceData, model): + def _run_convergence_checks(self, idata: arviz.InferenceData, model): + if not hasattr(idata, 'posterior'): + msg = "No posterior samples. Unable to run convergence checks" + warn = SamplerWarning(WarningType.BAD_PARAMS, msg, 'info', + None, None, None) + self._add_warnings([warn]) + return + if idata.posterior.sizes['chain'] == 1: msg = ("Only one chain was sampled, this makes it impossible to " "run some convergence checks") diff --git a/pymc3/sampling.py b/pymc3/sampling.py index 4b76cb8184b..5b381403b0b 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -502,6 +502,7 @@ def sample( "random_seed": random_seed, "cores": cores, "callback": callback, + "discard_tuned_samples": discard_tuned_samples, } sample_args.update(kwargs) @@ -1347,6 +1348,7 @@ def _mp_sample( trace=None, model=None, callback=None, + discard_tuned_samples=True, **kwargs ): """Main iteration for multiprocess sampling. @@ -1439,7 +1441,10 @@ def _mp_sample( raise return MultiTrace(traces) except KeyboardInterrupt: - traces, length = _choose_chains(traces, tune) + if discard_tuned_samples: + traces, length = _choose_chains(traces, tune) + else: + traces, length = _choose_chains(traces, 0) return MultiTrace(traces)[:length] finally: for trace in traces: