From 90f48ed22e3c28652903fa27aaaeff59aa29acdf Mon Sep 17 00:00:00 2001
From: Adrian Seyboldt
Date: Tue, 7 Jul 2020 20:45:10 +0200
Subject: [PATCH] Show pickling issues in notebook on Windows (#3991)

* Merge close remote connection

* Manually pickle step method in multiprocess sampling

* Fix tests for extra divergence info

* Add test for remote process crash

* Better formatting in test_parallel_sampling

Co-authored-by: Junpeng Lao

* Use mp_ctx forkserver on MacOS

* Add test for pickle with dill

Co-authored-by: Junpeng Lao
---
 RELEASE-NOTES.md                      |   5 +
 pymc3/__init__.py                     |   6 -
 pymc3/parallel_sampling.py            | 220 ++++++++++++++++++--------
 pymc3/sampling.py                     |  31 +++-
 pymc3/tests/test_parallel_sampling.py |  74 ++++++++-
 requirements-dev.txt                  |   1 +
 6 files changed, 254 insertions(+), 83 deletions(-)

diff --git a/RELEASE-NOTES.md b/RELEASE-NOTES.md
index c678272efbc..9ece22e92c6 100644
--- a/RELEASE-NOTES.md
+++ b/RELEASE-NOTES.md
@@ -1,10 +1,15 @@
 # Release Notes

 ## PyMC3 3.9.x (on deck)
+
+### Maintenance
+- Fix an error on Windows and Mac where the error message from unpickling models did not show up in the notebook, or where sampling froze when a worker process crashed (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991)).
+
 ### Documentation
 - Notebook on [multilevel modeling](https://docs.pymc.io/notebooks/multilevel_modeling.html) has been rewritten to showcase ArviZ and xarray usage for inference result analysis (see [#3963](https://github.com/pymc-devs/pymc3/pull/3963))

 ### New features
+- Introduce optional arguments to `pm.sample`: `mp_ctx` to control how the processes for parallel sampling are started, and `pickle_backend` to specify which library is used to pickle models in parallel sampling when the multiprocessing context is not of type `fork`. (see [#3991](https://github.com/pymc-devs/pymc3/pull/3991))
 - Add sampler stats `process_time_diff`, `perf_counter_diff` and `perf_counter_start`, that record wall and CPU times for each NUTS and HMC sample (see [#3986](https://github.com/pymc-devs/pymc3/pull/3986)).
 ## PyMC3 3.9.2 (24 June 2020)
diff --git a/pymc3/__init__.py b/pymc3/__init__.py
index 21ec7f3a591..ad47bbbfe28 100644
--- a/pymc3/__init__.py
+++ b/pymc3/__init__.py
@@ -27,12 +27,6 @@
 handler = logging.StreamHandler()
 _log.addHandler(handler)

-# Set start method to forkserver for MacOS to enable multiprocessing
-# Closes issue https://github.com/pymc-devs/pymc3/issues/3849
-sys = platform.system()
-if sys == "Darwin":
-    new_context = mp.get_context("forkserver")
-

 def __set_compiler_flags():
     # Workarounds for Theano compiler problems on various platforms
diff --git a/pymc3/parallel_sampling.py b/pymc3/parallel_sampling.py
index 3caa4ff543d..fb4b464b01c 100644
--- a/pymc3/parallel_sampling.py
+++ b/pymc3/parallel_sampling.py
@@ -17,10 +17,11 @@
 import ctypes
 import time
 import logging
+import pickle
 from collections import namedtuple
 import traceback
+import platform
 from pymc3.exceptions import SamplingError
-import errno

 import numpy as np
 from fastprogress.fastprogress import progress_bar
@@ -30,37 +31,6 @@
 logger = logging.getLogger("pymc3")


-def _get_broken_pipe_exception():
-    import sys
-
-    if sys.platform == "win32":
-        return RuntimeError(
-            "The communication pipe between the main process "
-            "and its spawned children is broken.\n"
-            "In Windows OS, this usually means that the child "
-            "process raised an exception while it was being "
-            "spawned, before it was setup to communicate to "
-            "the main process.\n"
-            "The exceptions raised by the child process while "
-            "spawning cannot be caught or handled from the "
-            "main process, and when running from an IPython or "
-            "jupyter notebook interactive kernel, the child's "
-            "exception and traceback appears to be lost.\n"
-            "A known way to see the child's error, and try to "
-            "fix or handle it, is to run the problematic code "
-            "as a batch script from a system's Command Prompt. "
-            "The child's exception will be printed to the "
-            "Command Promt's stderr, and it should be visible "
-            "above this error and traceback.\n"
-            "Note that if running a jupyter notebook that was "
-            "invoked from a Command Prompt, the child's "
-            "exception should have been printed to the Command "
-            "Prompt on which the notebook is running."
-        )
-    else:
-        return None
-
-
 class ParallelSamplingError(Exception):
     def __init__(self, message, chain, warnings=None):
         super().__init__(message)
@@ -104,26 +74,65 @@ def rebuild_exc(exc, tb):
 #   ('start',)


-class _Process(multiprocessing.Process):
+class _Process:
     """Seperate process for each chain.

     We communicate with the main process using a pipe,
     and send finished samples using shared memory.
     """

-    def __init__(self, name:str, msg_pipe, step_method, shared_point, draws:int, tune:int, seed):
-        super().__init__(daemon=True, name=name)
+    def __init__(
+        self,
+        name: str,
+        msg_pipe,
+        step_method,
+        step_method_is_pickled,
+        shared_point,
+        draws: int,
+        tune: int,
+        seed,
+        pickle_backend,
+    ):
         self._msg_pipe = msg_pipe
         self._step_method = step_method
+        self._step_method_is_pickled = step_method_is_pickled
         self._shared_point = shared_point
         self._seed = seed
         self._tt_seed = seed + 1
         self._draws = draws
         self._tune = tune
+        self._pickle_backend = pickle_backend
+
+    def _unpickle_step_method(self):
+        unpickle_error = (
+            "The model could not be unpickled. This is required for sampling "
+            "with more than one core and multiprocessing context spawn "
+            "or forkserver."
+        )
+        if self._step_method_is_pickled:
+            if self._pickle_backend == 'pickle':
+                try:
+                    self._step_method = pickle.loads(self._step_method)
+                except Exception:
+                    raise ValueError(unpickle_error)
+            elif self._pickle_backend == 'dill':
+                try:
+                    import dill
+                except ImportError:
+                    raise ValueError(
+                        "dill must be installed for pickle_backend='dill'."
+                    )
+                try:
+                    self._step_method = dill.loads(self._step_method)
+                except Exception:
+                    raise ValueError(unpickle_error)
+            else:
+                raise ValueError("Unknown pickle backend")

     def run(self):
         try:
             # We do not create this in __init__, as pickling this
             # would destroy the shared memory.
+            self._unpickle_step_method()
             self._point = self._make_numpy_refs()
             self._start_loop()
         except KeyboardInterrupt:
@@ -219,10 +228,25 @@ def _collect_warnings(self):
             return []


+def _run_process(*args):
+    _Process(*args).run()
+
+
 class ProcessAdapter:
     """Control a Chain process from the main thread."""

-    def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
+    def __init__(
+        self,
+        draws: int,
+        tune: int,
+        step_method,
+        step_method_pickled,
+        chain: int,
+        seed,
+        start,
+        mp_ctx,
+        pickle_backend,
+    ):
         self.chain = chain
         process_name = "worker_chain_%s" % chain
         self._msg_pipe, remote_conn = multiprocessing.Pipe()
@@ -237,7 +261,7 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
             if size != ctypes.c_size_t(size).value:
                 raise ValueError("Variable %s is too large" % name)

-            array = multiprocessing.sharedctypes.RawArray("c", size)
+            array = mp_ctx.RawArray("c", size)
             self._shared_point[name] = array
             array_np = np.frombuffer(array, dtype).reshape(shape)
             array_np[...] = start[name]
@@ -246,27 +270,31 @@ def __init__(self, draws:int, tune:int, step_method, chain:int, seed, start):
         self._readable = True
         self._num_samples = 0

-        self._process = _Process(
-            process_name,
-            remote_conn,
-            step_method,
-            self._shared_point,
-            draws,
-            tune,
-            seed,
+        if step_method_pickled is not None:
+            step_method_send = step_method_pickled
+        else:
+            step_method_send = step_method
+
+        self._process = mp_ctx.Process(
+            daemon=True,
+            name=process_name,
+            target=_run_process,
+            args=(
+                process_name,
+                remote_conn,
+                step_method_send,
+                step_method_pickled is not None,
+                self._shared_point,
+                draws,
+                tune,
+                seed,
+                pickle_backend,
+            )
         )
-        try:
-            self._process.start()
-        except IOError as e:
-            # Something may have gone wrong during the fork / spawn
-            if e.errno == errno.EPIPE:
-                exc = _get_broken_pipe_exception()
-                if exc is not None:
-                    # Sleep a little to give the child process time to flush
-                    # all its error message
-                    time.sleep(0.2)
-                    raise exc
-            raise
+        self._process.start()
+        # Close the remote pipe, so that we get notified if the other
+        # end is closed.
+        remote_conn.close()

     @property
     def shared_point_view(self):
@@ -277,15 +305,38 @@ def shared_point_view(self):
             raise RuntimeError()
         return self._point

+    def _send(self, msg, *args):
+        try:
+            self._msg_pipe.send((msg, *args))
+        except Exception:
+            # try to receive an error message
+            message = None
+            try:
+                message = self._msg_pipe.recv()
+            except Exception:
+                pass
+            if message is not None and message[0] == "error":
+                warns, old_error = message[1:]
+                if warns is not None:
+                    error = ParallelSamplingError(
+                        str(old_error),
+                        self.chain,
+                        warns
+                    )
+                else:
+                    error = RuntimeError("Chain %s failed." % self.chain)
+                raise error from old_error
+            raise
+
     def start(self):
-        self._msg_pipe.send(("start",))
+        self._send("start")

     def write_next(self):
         self._readable = False
-        self._msg_pipe.send(("write_next",))
+        self._send("write_next")

     def abort(self):
-        self._msg_pipe.send(("abort",))
+        self._send("abort")

     def join(self, timeout=None):
         self._process.join(timeout)
@@ -324,7 +375,7 @@ def terminate_all(processes, patience=2):
     for process in processes:
         try:
             process.abort()
-        except EOFError:
+        except Exception:
             pass

     start_time = time.time()
@@ -353,23 +404,52 @@ class ParallelSampler:
     def __init__(
         self,
-        draws:int,
-        tune:int,
-        chains:int,
-        cores:int,
-        seeds:list,
-        start_points:list,
+        draws: int,
+        tune: int,
+        chains: int,
+        cores: int,
+        seeds: list,
+        start_points: list,
         step_method,
-        start_chain_num:int=0,
-        progressbar:bool=True,
+        start_chain_num: int = 0,
+        progressbar: bool = True,
+        mp_ctx=None,
+        pickle_backend: str = 'pickle',
     ):
         if any(len(arg) != chains for arg in [seeds, start_points]):
             raise ValueError("Number of seeds and start_points must be %s." % chains)

+        if mp_ctx is None or isinstance(mp_ctx, str):
+            # Closes issue https://github.com/pymc-devs/pymc3/issues/3849
+            if platform.system() == 'Darwin':
+                mp_ctx = "forkserver"
+            mp_ctx = multiprocessing.get_context(mp_ctx)
+
+        step_method_pickled = None
+        if mp_ctx.get_start_method() != 'fork':
+            if pickle_backend == 'pickle':
+                step_method_pickled = pickle.dumps(step_method, protocol=-1)
+            elif pickle_backend == 'dill':
+                try:
+                    import dill
+                except ImportError:
+                    raise ValueError(
+                        "dill must be installed for pickle_backend='dill'."
+                    )
+                step_method_pickled = dill.dumps(step_method, protocol=-1)
+
         self._samplers = [
             ProcessAdapter(
-                draws, tune, step_method, chain + start_chain_num, seed, start
+                draws,
+                tune,
+                step_method,
+                step_method_pickled,
+                chain + start_chain_num,
+                seed,
+                start,
+                mp_ctx,
+                pickle_backend
             )
             for chain, seed, start in zip(range(chains), seeds, start_points)
         ]
diff --git a/pymc3/sampling.py b/pymc3/sampling.py
index 5b381403b0b..f550218302a 100644
--- a/pymc3/sampling.py
+++ b/pymc3/sampling.py
@@ -248,7 +248,9 @@ def sample(
     callback=None,
     *,
     return_inferencedata=None,
-    idata_kwargs:dict=None,
+    idata_kwargs: dict=None,
+    mp_ctx=None,
+    pickle_backend: str = 'pickle',
     **kwargs
 ):
     """Draw samples from the posterior using the given step methods.
@@ -336,6 +338,13 @@ def sample(
         Defaults to `False`, but we'll switch to `True` in an upcoming release.
     idata_kwargs : dict, optional
         Keyword arguments for `arviz.from_pymc3`
+    mp_ctx : multiprocessing.context.BaseContext
+        A multiprocessing context for parallel sampling. See multiprocessing
+        documentation for details.
+    pickle_backend : str
+        One of `'pickle'` or `'dill'`. The library used to pickle models
+        in parallel sampling if the multiprocessing context is not of type
+        `fork`.

     Returns
     -------
@@ -504,6 +513,10 @@ def sample(
             "callback": callback,
             "discard_tuned_samples": discard_tuned_samples,
         }
+        parallel_args = {
+            "pickle_backend": pickle_backend,
+            "mp_ctx": mp_ctx,
+        }

         sample_args.update(kwargs)

@@ -520,7 +533,7 @@ def sample(
         _log.info("Multiprocess sampling ({} chains in {} jobs)".format(chains, cores))
         _print_step_hierarchy(step)
         try:
-            trace = _mp_sample(**sample_args)
+            trace = _mp_sample(**sample_args, **parallel_args)
         except pickle.PickleError:
             _log.warning("Could not pickle model, sampling singlethreaded.")
             _log.debug("Pickling error:", exec_info=True)
@@ -1349,6 +1362,8 @@ def _mp_sample(
     model=None,
     callback=None,
     discard_tuned_samples=True,
+    mp_ctx=None,
+    pickle_backend='pickle',
     **kwargs
 ):
     """Main iteration for multiprocess sampling.
@@ -1411,7 +1426,17 @@ def _mp_sample(
         traces.append(strace)

     sampler = ps.ParallelSampler(
-        draws, tune, chains, cores, random_seed, start, step, chain, progressbar
+        draws,
+        tune,
+        chains,
+        cores,
+        random_seed,
+        start,
+        step,
+        chain,
+        progressbar,
+        mp_ctx=mp_ctx,
+        pickle_backend=pickle_backend,
     )
     try:
         try:
diff --git a/pymc3/tests/test_parallel_sampling.py b/pymc3/tests/test_parallel_sampling.py
index b6d8fc5472a..e61007ac8d0 100644
--- a/pymc3/tests/test_parallel_sampling.py
+++ b/pymc3/tests/test_parallel_sampling.py
@@ -11,9 +11,71 @@
 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 # See the License for the specific language governing permissions and
 # limitations under the License.
+import multiprocessing
+import os
+
+import pytest

 import pymc3.parallel_sampling as ps
 import pymc3 as pm
+import theano
+import theano.tensor as tt
+import numpy as np
+
+
+def test_context():
+    with pm.Model():
+        pm.Normal('x')
+        ctx = multiprocessing.get_context('spawn')
+        pm.sample(tune=2, draws=2, chains=2, cores=2, mp_ctx=ctx)
+
+
+class NoUnpickle:
+    def __getstate__(self):
+        return self.__dict__.copy()
+
+    def __setstate__(self, state):
+        raise AttributeError("This fails")
+
+
+def test_bad_unpickle():
+    with pm.Model() as model:
+        pm.Normal('x')
+
+    with model:
+        step = pm.NUTS()
+        step.no_unpickle = NoUnpickle()
+        with pytest.raises(Exception) as exc_info:
+            pm.sample(tune=2, draws=2, mp_ctx='spawn', step=step,
+                      cores=2, chains=2, compute_convergence_checks=False)
+        assert 'could not be unpickled' in str(exc_info.getrepr(style='short'))
+
+
+tt_vector = tt.TensorType(theano.config.floatX, [False])
+
+
+@theano.as_op([tt_vector, tt.iscalar], [tt_vector])
+def _crash_remote_process(a, master_pid):
+    if os.getpid() != master_pid:
+        os.exit(0)
+    return 2 * np.array(a)
+
+
+def test_dill():
+    with pm.Model():
+        pm.Normal('x')
+        pm.sample(tune=1, draws=1, chains=2, cores=2, pickle_backend="dill", mp_ctx="spawn")
+
+
+def test_remote_pipe_closed():
+    master_pid = os.getpid()
+    with pm.Model():
+        x = pm.Normal('x', shape=2, mu=0.1)
+        tt_pid = tt.as_tensor_variable(np.array(master_pid, dtype='int32'))
+        pm.Normal('y', mu=_crash_remote_process(x, tt_pid), shape=2)
+
+        step = pm.Metropolis()
+        with pytest.raises(RuntimeError, match="Chain [0-9] failed"):
+            pm.sample(step=step, mp_ctx='spawn', tune=2, draws=2, cores=2, chains=2)


 def test_abort():
@@ -25,8 +87,10 @@ def test_abort():

     step = pm.CompoundStep([step1, step2])

-    proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
-                             start={'a': 1., 'b_log__': 2.})
+    ctx = multiprocessing.get_context()
+    proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1, mp_ctx=ctx,
+                             start={'a': 1., 'b_log__': 2.},
+                             step_method_pickled=None, pickle_backend='pickle')
     proc.start()
     proc.write_next()
     proc.abort()
@@ -42,8 +106,10 @@ def test_explicit_sample():

     step = pm.CompoundStep([step1, step2])

-    proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1,
-                             start={'a': 1., 'b_log__': 2.})
+    ctx = multiprocessing.get_context()
+    proc = ps.ProcessAdapter(10, 10, step, chain=3, seed=1, mp_ctx=ctx,
+                             start={'a': 1., 'b_log__': 2.},
+                             step_method_pickled=None, pickle_backend='pickle')
     proc.start()
     while True:
         proc.write_next()
diff --git a/requirements-dev.txt b/requirements-dev.txt
index 7544a31085a..4c3d2a2859c 100644
--- a/requirements-dev.txt
+++ b/requirements-dev.txt
@@ -18,3 +18,4 @@ seaborn>=0.8.1
 sphinx-autobuild==0.7.1
 sphinx>=1.5.5
 watermark
+dill
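
A minimal usage sketch of the two new `pm.sample` arguments, mirroring `test_context` and `test_dill` above. The toy model and the tune/draw counts here are arbitrary stand-ins:

    import multiprocessing

    import pymc3 as pm

    with pm.Model():
        pm.Normal("x")

        # Control how worker processes are started, either with a context
        # object or just its name ("fork", "spawn", "forkserver").
        ctx = multiprocessing.get_context("spawn")
        pm.sample(tune=100, draws=100, chains=2, cores=2, mp_ctx=ctx)

        # For non-fork start methods the step method must be serialized;
        # dill can handle some models that the stdlib pickle cannot.
        pm.sample(tune=100, draws=100, chains=2, cores=2,
                  mp_ctx="spawn", pickle_backend="dill")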
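The core idea of the patch, pickling the step method explicitly in the parent instead of letting `multiprocessing` do it implicitly while spawning, can be sketched in isolation as below. This is illustrative only; the dict `step_method` is a hypothetical stand-in for a real step method object:

    import multiprocessing
    import pickle

    step_method = {"step": "stand-in for a real step method"}

    ctx = multiprocessing.get_context("spawn")

    # Parent side (as in ParallelSampler.__init__): with a non-fork start
    # method, serialize up front so that a pickling failure raises here,
    # in the notebook, instead of disappearing inside the spawned child.
    step_method_pickled = None
    if ctx.get_start_method() != "fork":
        step_method_pickled = pickle.dumps(step_method, protocol=-1)

    # Worker side (as in _Process._unpickle_step_method): restore the step
    # method; a failure here is reported back over the message pipe.
    if step_method_pickled is not None:
        step_method = pickle.loads(step_method_pickled)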
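The `mp_ctx.RawArray` plus `np.frombuffer` pattern that `ProcessAdapter.__init__` uses for the shared sample buffers can also be demonstrated standalone. A sketch, with a made-up 4-element float64 buffer and worker function:

    import multiprocessing

    import numpy as np


    def _worker(shared):
        # The child maps the same buffer; writes are visible to the parent.
        arr = np.frombuffer(shared, dtype="float64")
        arr[...] = 1.0


    if __name__ == "__main__":
        ctx = multiprocessing.get_context("spawn")
        # Allocate raw shared memory through the context and view it as a
        # numpy array, as ProcessAdapter does for each model variable.
        shared = ctx.RawArray("c", 8 * 4)  # room for 4 float64 values
        arr = np.frombuffer(shared, dtype="float64")
        arr[...] = 0.0
        proc = ctx.Process(target=_worker, args=(shared,))
        proc.start()
        proc.join()
        print(arr)  # -> [1. 1. 1. 1.]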