Skip to content

Commit

Permalink
switch from pickle/dill to cloudpickle (pymc-devs#4858)
Browse files Browse the repository at this point in the history
* use cloudpickle for serialization

* add cloudpickle to requirements

* update tests for cloudpickle

* update release notes with cloudpickle

* update conda envs with cloudpickle

* remove special case serialization for DensityDist.logp

* add pickle import back in for pickle.PickleError

* remove strict error message check in test
  • Loading branch information
Spaak authored and ricardoV94 committed Jul 15, 2021
1 parent 48f7558 commit 4f0ce74
Show file tree
Hide file tree
Showing 17 changed files with 42 additions and 105 deletions.
1 change: 1 addition & 0 deletions RELEASE-NOTES.md
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
- Logp method of `Uniform` and `DiscreteUniform` no longer depends on `pymc3.distributions.dist_math.bound` for proper evaluation (see [#4541](https://github.com/pymc-devs/pymc3/pull/4541)).
- `Model.RV_dims` and `Model.coords` are now read-only properties. To modify the `coords` dictionary use `Model.add_coord`. Also `dims` or coordinate values that are `None` will be auto-completed (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)).
- The length of `dims` in the model is now tracked symbolically through `Model.dim_lengths` (see [#4625](https://github.com/pymc-devs/pymc3/pull/4625)).
- We now include `cloudpickle` as a required dependency, and no longer depend on `dill` (see [#4858](https://github.com/pymc-devs/pymc3/pull/4858)).
- ...

## PyMC3 3.11.2 (14 March 2021)
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py37.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- aesara>=2.0.9
- arviz>=0.11.2
- cachetools>=4.2.1
- dill
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- aesara>=2.0.9
- arviz>=0.11.2
- cachetools>=4.2.1
- dill
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/environment-dev-py39.yml
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@ dependencies:
- aesara>=2.0.9
- arviz>=0.11.2
- cachetools>=4.2.1
- dill
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- ipython>=7.16
Expand Down
2 changes: 1 addition & 1 deletion conda-envs/windows-environment-dev-py38.yml
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@ dependencies:
- aesara>=2.0.9
- arviz>=0.11.2
- cachetools>=4.2.1
- dill
- cloudpickle
- fastprogress>=0.2.0
- h5py>=2.7
- libpython
Expand Down
22 changes: 0 additions & 22 deletions pymc3/distributions/distribution.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,7 +23,6 @@

import aesara
import aesara.tensor as at
import dill

from aesara.tensor.random.op import RandomVariable
from aesara.tensor.random.var import RandomStateSharedVariable
Expand Down Expand Up @@ -533,26 +532,5 @@ def __init__(
self.wrap_random_with_dist_shape = wrap_random_with_dist_shape
self.check_shape_in_random = check_shape_in_random

def __getstate__(self):
# We use dill to serialize the logp function, as this is almost
# always defined in the notebook and won't be pickled correctly.
# Fix https://github.com/pymc-devs/pymc3/issues/3844
try:
logp = dill.dumps(self.logp)
except RecursionError as err:
if type(self.logp) == types.MethodType:
raise ValueError(
"logp for DensityDist is a bound method, leading to RecursionError while serializing"
) from err
else:
raise err
vals = self.__dict__.copy()
vals["logp"] = logp
return vals

def __setstate__(self, vals):
vals["logp"] = dill.loads(vals["logp"])
self.__dict__ = vals

def _distr_parameters_for_repr(self):
return []
37 changes: 6 additions & 31 deletions pymc3/parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,13 +16,13 @@
import logging
import multiprocessing
import multiprocessing.sharedctypes
import pickle
import platform
import time
import traceback

from collections import namedtuple

import cloudpickle
import numpy as np

from fastprogress.fastprogress import progress_bar
Expand Down Expand Up @@ -93,7 +93,6 @@ def __init__(
draws: int,
tune: int,
seed,
pickle_backend,
):
self._msg_pipe = msg_pipe
self._step_method = step_method
Expand All @@ -103,7 +102,6 @@ def __init__(
self._at_seed = seed + 1
self._draws = draws
self._tune = tune
self._pickle_backend = pickle_backend

def _unpickle_step_method(self):
unpickle_error = (
Expand All @@ -112,22 +110,10 @@ def _unpickle_step_method(self):
"or forkserver."
)
if self._step_method_is_pickled:
if self._pickle_backend == "pickle":
try:
self._step_method = pickle.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)
elif self._pickle_backend == "dill":
try:
import dill
except ImportError:
raise ValueError("dill must be installed for pickle_backend='dill'.")
try:
self._step_method = dill.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)
else:
raise ValueError("Unknown pickle backend")
try:
self._step_method = cloudpickle.loads(self._step_method)
except Exception:
raise ValueError(unpickle_error)

def run(self):
try:
Expand Down Expand Up @@ -243,7 +229,6 @@ def __init__(
seed,
start,
mp_ctx,
pickle_backend,
):
self.chain = chain
process_name = "worker_chain_%s" % chain
Expand Down Expand Up @@ -287,7 +272,6 @@ def __init__(
draws,
tune,
seed,
pickle_backend,
),
)
self._process.start()
Expand Down Expand Up @@ -406,7 +390,6 @@ def __init__(
start_chain_num: int = 0,
progressbar: bool = True,
mp_ctx=None,
pickle_backend: str = "pickle",
):

if any(len(arg) != chains for arg in [seeds, start_points]):
Expand All @@ -420,14 +403,7 @@ def __init__(

step_method_pickled = None
if mp_ctx.get_start_method() != "fork":
if pickle_backend == "pickle":
step_method_pickled = pickle.dumps(step_method, protocol=-1)
elif pickle_backend == "dill":
try:
import dill
except ImportError:
raise ValueError("dill must be installed for pickle_backend='dill'.")
step_method_pickled = dill.dumps(step_method, protocol=-1)
step_method_pickled = cloudpickle.dumps(step_method, protocol=-1)

self._samplers = [
ProcessAdapter(
Expand All @@ -439,7 +415,6 @@ def __init__(
seed,
start,
mp_ctx,
pickle_backend,
)
for chain, seed, start in zip(range(chains), seeds, start_points)
]
Expand Down
13 changes: 3 additions & 10 deletions pymc3/sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@
from typing import Any, Dict, Iterable, List, Optional, Set, Union, cast

import aesara.gradient as tg
import cloudpickle
import numpy as np
import xarray

Expand Down Expand Up @@ -268,7 +269,6 @@ def sample(
return_inferencedata=None,
idata_kwargs: dict = None,
mp_ctx=None,
pickle_backend: str = "pickle",
**kwargs,
):
r"""Draw samples from the posterior using the given step methods.
Expand Down Expand Up @@ -362,10 +362,6 @@ def sample(
mp_ctx : multiprocessing.context.BaseContext
A multiprocessing context for parallel sampling. See multiprocessing
documentation for details.
pickle_backend : str
One of `'pickle'` or `'dill'`. The library used to pickle models
in parallel sampling if the multiprocessing context is not of type
`fork`.
Returns
-------
Expand Down Expand Up @@ -548,7 +544,6 @@ def sample(
"discard_tuned_samples": discard_tuned_samples,
}
parallel_args = {
"pickle_backend": pickle_backend,
"mp_ctx": mp_ctx,
}

Expand Down Expand Up @@ -1100,7 +1095,7 @@ def __init__(self, steppers, parallelize, progressbar=True):
enumerate(progress_bar(steppers)) if progressbar else enumerate(steppers)
):
secondary_end, primary_end = multiprocessing.Pipe()
stepper_dumps = pickle.dumps(stepper, protocol=4)
stepper_dumps = cloudpickle.dumps(stepper, protocol=4)
process = multiprocessing.Process(
target=self.__class__._run_secondary,
args=(c, stepper_dumps, secondary_end),
Expand Down Expand Up @@ -1159,7 +1154,7 @@ def _run_secondary(c, stepper_dumps, secondary_end):
# re-seed each child process to make them unique
np.random.seed(None)
try:
stepper = pickle.loads(stepper_dumps)
stepper = cloudpickle.loads(stepper_dumps)
# the stepper is not necessarily a PopulationArraySharedStep itself,
# but rather a CompoundStep. PopulationArrayStepShared.population
# has to be updated, therefore we identify the substeppers first.
Expand Down Expand Up @@ -1418,7 +1413,6 @@ def _mp_sample(
callback=None,
discard_tuned_samples=True,
mp_ctx=None,
pickle_backend="pickle",
**kwargs,
):
"""Main iteration for multiprocess sampling.
Expand Down Expand Up @@ -1491,7 +1485,6 @@ def _mp_sample(
chain,
progressbar,
mp_ctx=mp_ctx,
pickle_backend=pickle_backend,
)
try:
try:
Expand Down
4 changes: 2 additions & 2 deletions pymc3/tests/test_distributions.py
Original file line number Diff line number Diff line change
Expand Up @@ -3109,9 +3109,9 @@ def func(x):
y = pm.DensityDist("y", func)
pm.sample(draws=5, tune=1, mp_ctx="spawn")

import pickle
import cloudpickle

pickle.loads(pickle.dumps(y))
cloudpickle.loads(cloudpickle.dumps(y))


def test_distinct_rvs():
Expand Down
8 changes: 4 additions & 4 deletions pymc3/tests/test_minibatches.py
Original file line number Diff line number Diff line change
Expand Up @@ -13,9 +13,9 @@
# limitations under the License.

import itertools
import pickle

import aesara
import cloudpickle
import numpy as np
import pytest

Expand Down Expand Up @@ -132,10 +132,10 @@ def gen():

def test_pickling(self, datagen):
gen = generator(datagen)
pickle.loads(pickle.dumps(gen))
cloudpickle.loads(cloudpickle.dumps(gen))
bad_gen = generator(integers())
with pytest.raises(Exception):
pickle.dumps(bad_gen)
with pytest.raises(TypeError):
cloudpickle.dumps(bad_gen)

def test_gen_cloning_with_shape_change(self, datagen):
gen = generator(datagen)
Expand Down
10 changes: 3 additions & 7 deletions pymc3/tests/test_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,14 +11,14 @@
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import pickle
import unittest

from functools import reduce

import aesara
import aesara.sparse as sparse
import aesara.tensor as at
import cloudpickle
import numpy as np
import numpy.ma as ma
import numpy.testing as npt
Expand Down Expand Up @@ -407,9 +407,7 @@ def test_model_pickle(tmpdir):
x = pm.Normal("x")
pm.Normal("y", observed=1)

file_path = tmpdir.join("model.p")
with open(file_path, "wb") as buff:
pickle.dump(model, buff)
cloudpickle.loads(cloudpickle.dumps(model))


def test_model_pickle_deterministic(tmpdir):
Expand All @@ -420,9 +418,7 @@ def test_model_pickle_deterministic(tmpdir):
pm.Deterministic("w", x / z)
pm.Normal("y", observed=1)

file_path = tmpdir.join("model.p")
with open(file_path, "wb") as buff:
pickle.dump(model, buff)
cloudpickle.loads(cloudpickle.dumps(model))


def test_model_vars():
Expand Down
8 changes: 0 additions & 8 deletions pymc3/tests/test_parallel_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -71,12 +71,6 @@ def _crash_remote_process(a, master_pid):
return 2 * np.array(a)


def test_dill():
with pm.Model():
pm.Normal("x")
pm.sample(tune=1, draws=1, chains=2, cores=2, pickle_backend="dill", mp_ctx="spawn")


def test_remote_pipe_closed():
master_pid = os.getpid()
with pm.Model():
Expand Down Expand Up @@ -112,7 +106,6 @@ def test_abort():
mp_ctx=ctx,
start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
step_method_pickled=None,
pickle_backend="pickle",
)
proc.start()
while True:
Expand Down Expand Up @@ -147,7 +140,6 @@ def test_explicit_sample():
mp_ctx=ctx,
start={"a": np.array([1.0]), "b_log__": np.array(2.0)},
step_method_pickled=None,
pickle_backend="pickle",
)
proc.start()
while True:
Expand Down
6 changes: 4 additions & 2 deletions pymc3/tests/test_pickling.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,8 @@
import pickle
import traceback

import cloudpickle

from pymc3.tests.models import simple_model


Expand All @@ -26,8 +28,8 @@ def test_model_roundtrip(self):
m = self.model
for proto in range(pickle.HIGHEST_PROTOCOL + 1):
try:
s = pickle.dumps(m, proto)
pickle.loads(s)
s = cloudpickle.dumps(m, proto)
cloudpickle.loads(s)
except Exception:
raise AssertionError(
"Exception while trying roundtrip with pickle protocol %d:\n" % proto
Expand Down
Loading

0 comments on commit 4f0ce74

Please sign in to comment.