Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Make SequenceLearner points hashable by passing the sequence to the function. #266

Closed
wants to merge 9 commits into from
42 changes: 18 additions & 24 deletions adaptive/learner/sequence_learner.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@
from adaptive.learner.base_learner import BaseLearner


class _IgnoreFirstArgument:
"""Remove the first argument from the call signature.
class _IndexToPoint:
basnijholt marked this conversation as resolved.
Show resolved Hide resolved
"""Call function with index of sequence.

The SequenceLearner's function receives a tuple ``(index, point)``
basnijholt marked this conversation as resolved.
Show resolved Hide resolved
but the original function only takes ``point``.
Expand All @@ -15,18 +15,19 @@ class _IgnoreFirstArgument:
pickable.
"""

def __init__(self, function):
def __init__(self, function, sequence):
self.function = function
self.sequence = sequence

def __call__(self, index_point, *args, **kwargs):
index, point = index_point
def __call__(self, index, *args, **kwargs):
point = self.sequence[index]
return self.function(point, *args, **kwargs)
basnijholt marked this conversation as resolved.
Show resolved Hide resolved

def __getstate__(self):
return self.function
return self.function, self.sequence

def __setstate__(self, function):
self.__init__(function)
def __setstate__(self, state):
self.__init__(*state)


class SequenceLearner(BaseLearner):
Expand All @@ -40,7 +41,7 @@ class SequenceLearner(BaseLearner):
Parameters
----------
function : callable
The function to learn. Must take a single element `sequence`.
The function to learn. Must take a single element of `sequence`.
sequence : sequence
The sequence to learn.

Expand All @@ -58,7 +59,7 @@ class SequenceLearner(BaseLearner):

def __init__(self, function, sequence):
self._original_function = function
self.function = _IgnoreFirstArgument(function)
self.function = _IndexToPoint(function, sequence)
self._to_do_indices = SortedSet({i for i, _ in enumerate(sequence)})
self._ntotal = len(sequence)
self.sequence = copy(sequence)
Expand All @@ -67,31 +68,26 @@ def __init__(self, function, sequence):

def ask(self, n, tell_pending=True):
indices = []
points = []
loss_improvements = []
for index in self._to_do_indices:
if len(points) >= n:
if len(indices) >= n:
break
point = self.sequence[index]
indices.append(index)
points.append((index, point))
loss_improvements.append(1 / self._ntotal)

if tell_pending:
for i, p in zip(indices, points):
self.tell_pending((i, p))
for index in indices:
self.tell_pending(index)

return points, loss_improvements
return indices, loss_improvements

def _get_data(self):
return self.data

def _set_data(self, data):
if data:
indices, values = zip(*data.items())
# the points aren't used by tell, so we can safely pass None
points = [(i, None) for i in indices]
self.tell_many(points, values)
self.tell_many(indices, values)

def loss(self, real=True):
if not (self._to_do_indices or self.pending_points):
Expand All @@ -105,14 +101,12 @@ def remove_unfinished(self):
self._to_do_indices.add(i)
self.pending_points = set()

def tell(self, point, value):
index, point = point
def tell(self, index, value):
self.data[index] = value
self.pending_points.discard(index)
self._to_do_indices.discard(index)

def tell_pending(self, point):
index, point = point
def tell_pending(self, index):
self.pending_points.add(index)
self._to_do_indices.discard(index)

Expand Down
31 changes: 5 additions & 26 deletions adaptive/tests/test_learners.py
Original file line number Diff line number Diff line change
Expand Up @@ -281,19 +281,9 @@ def test_adding_existing_data_is_idempotent(learner_type, f, learner_kwargs):
M = random.randint(10, 30)
pls = zip(*learner.ask(M))
cpls = zip(*control.ask(M))
if learner_type is SequenceLearner:
# The SequenceLearner's points might not be hasable
points, values = zip(*pls)
indices, points = zip(*points)

cpoints, cvalues = zip(*cpls)
cindices, cpoints = zip(*cpoints)
assert (np.array(points) == np.array(cpoints)).all()
assert values == cvalues
assert indices == cindices
else:
# Point ordering is not defined, so compare as sets
assert set(pls) == set(cpls)
# Point ordering is not defined, so compare as sets
assert set(pls) == set(cpls)


# XXX: This *should* pass (https://github.com/python-adaptive/adaptive/issues/55)
Expand Down Expand Up @@ -324,20 +314,9 @@ def test_adding_non_chosen_data(learner_type, f, learner_kwargs):
pls = zip(*learner.ask(M))
cpls = zip(*control.ask(M))

if learner_type is SequenceLearner:
# The SequenceLearner's points might not be hasable
points, values = zip(*pls)
indices, points = zip(*points)

cpoints, cvalues = zip(*cpls)
cindices, cpoints = zip(*cpoints)
assert (np.array(points) == np.array(cpoints)).all()
assert values == cvalues
assert indices == cindices
else:
# Point ordering within a single call to 'ask'
# is not guaranteed to be the same by the API.
assert set(pls) == set(cpls)
# Point ordering within a single call to 'ask'
# is not guaranteed to be the same by the API.
assert set(pls) == set(cpls)


@run_with(Learner1D, xfail(Learner2D), xfail(LearnerND), AverageLearner)
Expand Down
24 changes: 24 additions & 0 deletions adaptive/tests/test_sequence_learner.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
import asyncio

import numpy as np

import adaptive


def might_fail(dct):
basnijholt marked this conversation as resolved.
Show resolved Hide resolved
import random

if random.random() < 0.5:
raise Exception()
return dct["x"]


def test_fail_with_sequence_of_unhashable():
# https://github.com/python-adaptive/adaptive/issues/265
seq = [dict(x=x) for x in np.linspace(-1, 1, 101)] # unhashable
basnijholt marked this conversation as resolved.
Show resolved Hide resolved
learner = adaptive.SequenceLearner(might_fail, sequence=seq)
runner = adaptive.Runner(
learner, goal=adaptive.SequenceLearner.done, retries=100
) # with 100 retries the test will fail once in 10^31
asyncio.get_event_loop().run_until_complete(runner.task)
assert runner.status() == "finished"