Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

make learners picklable #264

Merged
merged 23 commits into from
Apr 24, 2020
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
Show all changes
23 commits
Select commit Hold shift + click to select a range
f75b099
make Learner1D picklable
basnijholt Apr 9, 2020
77f3613
make Learner2D picklable
basnijholt Apr 9, 2020
06c1dd2
make AverageLearner picklable
basnijholt Apr 9, 2020
062c2f7
make IntegratorLearner picklable
basnijholt Apr 9, 2020
cf7dad4
make SequenceLearner picklable
basnijholt Apr 9, 2020
b71a058
make BalancingLearner picklable
basnijholt Apr 9, 2020
7ff01d2
make DataSaver picklable
basnijholt Apr 9, 2020
99fadf1
add tests for pickling
basnijholt Apr 10, 2020
5831592
add cloudpickle to testing dependencies
basnijholt Apr 10, 2020
dfd8b0c
test serialization with pickle, cloudpickle, and dill
basnijholt Apr 10, 2020
d6172e0
only test cloudpickle and dill if installed
basnijholt Apr 10, 2020
64bb2e6
test for identical ask and loss response
basnijholt Apr 12, 2020
7bc0ade
add flaky
basnijholt Apr 12, 2020
978a62c
use an exact equality in checking the number of points
basnijholt Apr 12, 2020
1a31669
set learner._recompute_losses_factor = 1
basnijholt Apr 13, 2020
2e40ebb
use exact equalities
basnijholt Apr 13, 2020
44c6446
make Learner1D's datastructures identical before and after pickling
basnijholt Apr 14, 2020
f7a3b03
make Learner2D's datastructures identical before and after pickling
basnijholt Apr 14, 2020
30619d5
do not specially treat Learner1D and Learner2D
basnijholt Apr 14, 2020
9dccd05
test for more points
basnijholt Apr 14, 2020
1e4c495
refactor tests
basnijholt Apr 15, 2020
ca28f2e
do not initialize child-learners twice in BalancingLearner
basnijholt Apr 15, 2020
acc5400
do not initialize child-learners twice in DataSaver
basnijholt Apr 15, 2020
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
13 changes: 13 additions & 0 deletions adaptive/learner/average_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -144,3 +144,16 @@ def _get_data(self):

def _set_data(self, data):
self.data, self.npoints, self.sum_f, self.sum_f_sq = data

def __getstate__(self):
    """Return a picklable tuple: (function, atol, rtol, data)."""
    init_args = (self.function, self.atol, self.rtol)
    return init_args + (self._get_data(),)

def __setstate__(self, state):
    """Restore the learner by re-initializing it and re-loading its data."""
    *init_args, data = state
    # Re-running __init__ rebuilds every derived attribute consistently.
    self.__init__(*init_args)
    self._set_data(data)
11 changes: 11 additions & 0 deletions adaptive/learner/balancing_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -440,3 +440,14 @@ def _get_data(self):
def _set_data(self, data):
    """Distribute the saved per-learner data over the child learners."""
    for learner, learner_data in zip(self.learners, data):
        learner._set_data(learner_data)

def __getstate__(self):
    """Return (child learners, default cdims, point-selection strategy).

    The children carry their own data, so no separate data blob is needed.
    """
    return (self.learners, self._cdims_default, self.strategy)

def __setstate__(self, state):
    """Re-initialize from the pickled children, cdims, and strategy."""
    children, default_cdims, strategy = state
    self.__init__(children, cdims=default_cdims, strategy=strategy)
12 changes: 12 additions & 0 deletions adaptive/learner/data_saver.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,6 +51,18 @@ def _set_data(self, data):
learner_data, self.extra_data = data
self.learner._set_data(learner_data)

def __getstate__(self):
    """Return (wrapped learner, arg_picker, extra saved data)."""
    return (self.learner, self.arg_picker, self.extra_data)

def __setstate__(self, state):
    """Rebuild the wrapper, then reattach the saved extra data."""
    learner, arg_picker, extra = state
    self.__init__(learner, arg_picker)
    self.extra_data = extra

@copy_docstring_from(BaseLearner.save)
def save(self, fname, compress=True):
# We copy this method because the 'DataSaver' is not a
Expand Down
13 changes: 13 additions & 0 deletions adaptive/learner/integrator_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -591,3 +591,16 @@ def _set_data(self, data):
self.x_mapping = defaultdict(lambda: SortedSet([], key=attrgetter("rdepth")))
for k, _set in x_mapping.items():
self.x_mapping[k].update(_set)

def __getstate__(self):
    """Return a picklable tuple: (function, bounds, tol, data)."""
    init_args = (self.function, self.bounds, self.tol)
    return init_args + (self._get_data(),)

def __setstate__(self, state):
    """Restore the learner by re-initializing it and re-loading its data."""
    *init_args, data = state
    self.__init__(*init_args)
    self._set_data(data)
17 changes: 17 additions & 0 deletions adaptive/learner/learner1D.py
Original file line number Diff line number Diff line change
Expand Up @@ -625,6 +625,23 @@ def _set_data(self, data):
if data:
self.tell_many(*zip(*data.items()))

def __getstate__(self):
    """Return everything needed to rebuild the learner exactly.

    The loss containers are copied into plain dicts because the
    SortedDict / ItemSortedDict used internally cannot be pickled.
    """
    losses = dict(self.losses)
    losses_combined = dict(self.losses_combined)
    return (
        self.function,
        self.bounds,
        self.loss_per_interval,
        losses,
        losses_combined,
        self._get_data(),
    )

def __setstate__(self, state):
    """Re-initialize, reload the data, then restore the exact losses."""
    function, bounds, loss_per_interval, losses, losses_combined, data = state
    self.__init__(function, bounds, loss_per_interval)
    self._set_data(data)
    # Overwrite the recomputed losses so the restored learner matches
    # the pickled one exactly.
    self.losses.update(losses)
    self.losses_combined.update(losses_combined)


def loss_manager(x_scale):
def sort_key(ival, loss):
Expand Down
15 changes: 15 additions & 0 deletions adaptive/learner/learner2D.py
Original file line number Diff line number Diff line change
Expand Up @@ -706,3 +706,18 @@ def _set_data(self, data):
for point in copy(self._stack):
if point in self.data:
self._stack.pop(point)

def __getstate__(self):
    """Return (function, bounds, loss_per_triangle, point stack, data)."""
    init_args = (self.function, self.bounds, self.loss_per_triangle)
    return init_args + (self._stack, self._get_data())

def __setstate__(self, state):
    """Re-initialize, reload the data, then restore the pending-point stack."""
    function, bounds, loss_per_triangle, stack, data = state
    self.__init__(function, bounds, loss_per_triangle)
    self._set_data(data)
    # Restore the exact stack so `ask` behaves identically after unpickling.
    self._stack = stack
32 changes: 22 additions & 10 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -83,16 +83,6 @@ def ask(self, n, tell_pending=True):

return points, loss_improvements

def _get_data(self):
    """Return the mapping from sequence index to evaluated value."""
    return self.data

def _set_data(self, data):
    """Reload evaluated values; keys of `data` are indices into the sequence."""
    if not data:
        return
    indices, values = zip(*data.items())
    # `tell` never reads the value part of a point, so None is a safe placeholder.
    points = [(index, None) for index in indices]
    self.tell_many(points, values)

def loss(self, real=True):
if not (self._to_do_indices or self.pending_points):
return 0
Expand Down Expand Up @@ -128,3 +118,25 @@ def result(self):
@property
def npoints(self):
    """Number of points that have been evaluated so far."""
    data = self.data
    return len(data)

def _get_data(self):
    """Return the mapping from sequence index to evaluated value."""
    return self.data

def _set_data(self, data):
    """Reload evaluated values; keys of `data` are indices into the sequence."""
    if not data:
        return
    indices, values = zip(*data.items())
    # `tell` never reads the value part of a point, so None is a safe placeholder.
    points = [(index, None) for index in indices]
    self.tell_many(points, values)

def __getstate__(self):
    """Return (unwrapped function, sequence, data) for pickling."""
    return (self._original_function, self.sequence, self._get_data())

def __setstate__(self, state):
    """Rebuild from the unwrapped function and sequence, then reload data."""
    function, sequence, data = state
    self.__init__(function, sequence)
    self._set_data(data)
117 changes: 117 additions & 0 deletions adaptive/tests/test_pickling.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,117 @@
import pickle

import pytest

from adaptive.learner import (
AverageLearner,
BalancingLearner,
DataSaver,
IntegratorLearner,
Learner1D,
Learner2D,
SequenceLearner,
)
from adaptive.runner import simple

# Optional serializers: record availability so the parametrization below can
# skip them gracefully when they are not installed.
try:
    import cloudpickle
except ModuleNotFoundError:
    with_cloudpickle = False
else:
    with_cloudpickle = True

try:
    import dill
except ModuleNotFoundError:
    with_dill = False
else:
    with_dill = True


def goal_1(learner):
    """Stop condition: the learner holds exactly 10 points."""
    n = learner.npoints
    return n == 10


def goal_2(learner):
    """Stop condition: the learner holds exactly 20 points."""
    n = learner.npoints
    return n == 20


def pickleable_f(x):
    """Deterministic pseudo-random value; defined at module level so the
    plain `pickle` module can serialize it."""
    key = str(x)
    return hash(key) / 2 ** 63


nonpickleable_f = lambda x: hash(str(x)) / 2 ** 63 # noqa: E731


def identity_function(x):
    """Return the argument unchanged (used as a trivial ``arg_picker``)."""
    return x


def datasaver(f, learner_type, learner_kwargs):
    """Wrap a freshly created learner in a ``DataSaver``."""
    child = learner_type(f, **learner_kwargs)
    return DataSaver(learner=child, arg_picker=identity_function)


def balancing_learner(f, learner_type, learner_kwargs):
    """Create a ``BalancingLearner`` over two identical child learners."""
    children = [learner_type(f, **learner_kwargs) for _ in range(2)]
    return BalancingLearner(children)


# Every learner type paired with the kwargs needed to construct it.
learners_pairs = [
    (Learner1D, {"bounds": (-1, 1)}),
    (Learner2D, {"bounds": [(-1, 1), (-1, 1)]}),
    (SequenceLearner, {"sequence": list(range(100))}),
    (IntegratorLearner, {"bounds": (0, 1), "tol": 1e-3}),
    (AverageLearner, {"atol": 0.1}),
    (datasaver, {"learner_type": Learner1D, "learner_kwargs": {"bounds": (-1, 1)}}),
    (
        balancing_learner,
        {"learner_type": Learner1D, "learner_kwargs": {"bounds": (-1, 1)}},
    ),
]

# Each serializer is paired with a function it is able to serialize;
# only `pickle` needs a module-level (picklable) function.
serializers = [(pickle, pickleable_f)]
if with_cloudpickle:
    serializers.append((cloudpickle, nonpickleable_f))
if with_dill:
    serializers.append((dill, nonpickleable_f))

# Cartesian product of serializers and learner configurations.
learners = []
for serializer, f in serializers:
    for learner_type, learner_kwargs in learners_pairs:
        learners.append((learner_type, learner_kwargs, serializer, f))


@pytest.mark.parametrize(
    "learner_type, learner_kwargs, serializer, f", learners,
)
def test_serialization_for(learner_type, learner_kwargs, serializer, f):
    """Round-trip a learner through `serializer` and check it is unchanged."""
    learner = learner_type(f, **learner_kwargs)
    simple(learner, goal_1)

    # Serialize first, then snapshot the state we will compare against.
    learner_bytes = serializer.dumps(learner)
    loss = learner.loss()
    asked = learner.ask(10)
    data = learner.data

    # Drop every reference to the originals so nothing can leak through.
    del f
    del learner

    restored = serializer.loads(learner_bytes)
    assert restored.npoints == 10
    assert loss == restored.loss()
    assert data == restored.data
    assert asked == restored.ask(10)

    # `ask` mutated the restored learner; load a fresh copy and keep running.
    restored = serializer.loads(learner_bytes)
    simple(restored, goal_2)
    assert restored.npoints == 20
4 changes: 3 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,8 +51,10 @@ def get_version_and_cmdclass(package_name):
"pre_commit",
],
"other": [
"ipyparallel>=6.2.5", # because of https://github.com/ipython/ipyparallel/issues/404
"cloudpickle",
"dill",
"distributed",
"ipyparallel>=6.2.5", # because of https://github.com/ipython/ipyparallel/issues/404
"loky",
"scikit-optimize",
"wexpect" if os.name == "nt" else "pexpect",
Expand Down