diff --git a/pymc3/sampling.py b/pymc3/sampling.py index c8fe6708ac7..a5d567cbf9c 100644 --- a/pymc3/sampling.py +++ b/pymc3/sampling.py @@ -470,7 +470,7 @@ def sample( _log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores)) _print_step_hierarchy(step) try: - trace = _mp_sample(**sample_args) + trace = _mp_sample(**sample_args, discard_tuned_samples=discard_tuned_samples) except pickle.PickleError: _log.warning("Could not pickle model, sampling singlethreaded.") _log.debug("Pickling error:", exec_info=True) @@ -1243,6 +1243,7 @@ def _mp_sample( progressbar=True, trace=None, model=None, + discard_tuned_samples=True, **kwargs ): """Main iteration for multiprocess sampling. @@ -1325,7 +1326,10 @@ def _mp_sample( raise return MultiTrace(traces) except KeyboardInterrupt: - traces, length = _choose_chains(traces, tune) + if discard_tuned_samples: + traces, length = _choose_chains(traces, tune) + else: + traces, length = _choose_chains(traces, 0) return MultiTrace(traces)[:length] finally: for trace in traces: