Skip to content

Commit

Permalink
Merge pull request #268 from python-adaptive/unhashable-runner-points
Browse files Browse the repository at this point in the history
make the Runner work with unhashable points
  • Loading branch information
basnijholt authored Apr 24, 2020
2 parents de0cc0c + c7a12a4 commit 5754320
Show file tree
Hide file tree
Showing 2 changed files with 80 additions and 50 deletions.
124 changes: 77 additions & 47 deletions adaptive/runner.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,9 @@
import abc
import asyncio
import concurrent.futures as concurrent
import functools
import inspect
import itertools
import pickle
import sys
import time
Expand Down Expand Up @@ -91,14 +93,14 @@ class BaseRunner(metaclass=abc.ABCMeta):
log : list or None
Record of the method calls made to the learner, in the format
``(method_name, *args)``.
to_retry : dict
Mapping of ``{point: n_fails, ...}``. When a point has failed
to_retry : list of tuples
List of ``(point, n_fails)``. When a point has failed
``runner.retries`` times it is removed but will be present
in ``runner.tracebacks``.
tracebacks : dict
A mapping of point to the traceback if that point failed.
pending_points : dict
A mapping of `~concurrent.futures.Future`\s to points.
tracebacks : list of tuples
List of ``(point, tb)`` for points that failed.
pending_points : list of tuples
A list of tuples with ``(concurrent.futures.Future, point)``.
Methods
-------
Expand Down Expand Up @@ -126,7 +128,7 @@ def __init__(

self._max_tasks = ntasks

self.pending_points = {}
self._pending_tasks = {} # mapping from concurrent.futures.Future → point id

# if we instantiate our own executor, then we are also responsible
# for calling 'shutdown'
Expand All @@ -143,14 +145,20 @@ def __init__(
# Error handling attributes
self.retries = retries
self.raise_if_retries_exceeded = raise_if_retries_exceeded
self.to_retry = {}
self.tracebacks = {}
self._to_retry = {}
self._tracebacks = {}

self._id_to_point = {}
self._next_id = functools.partial(
next, itertools.count()
) # some unique id to be associated with each point

def _get_max_tasks(self):
return self._max_tasks or _get_ncores(self.executor)

def _do_raise(self, e, x):
tb = self.tracebacks[x]
def _do_raise(self, e, i):
tb = self._tracebacks[i]
x = self._id_to_point[i]
raise RuntimeError(
"An error occured while evaluating "
f'"learner.function({x})". '
Expand All @@ -162,15 +170,21 @@ def do_log(self):
return self.log is not None

def _ask(self, n):
points = [
p for p in self.to_retry.keys() if p not in self.pending_points.values()
][:n]
loss_improvements = len(points) * [float("inf")]
if len(points) < n:
new_points, new_losses = self.learner.ask(n - len(points))
points += new_points
pending_ids = self._pending_tasks.values()
# using generator here because we only need until `n`
pids_gen = (pid for pid in self._to_retry.keys() if pid not in pending_ids)
pids = list(itertools.islice(pids_gen, n))

loss_improvements = len(pids) * [float("inf")]

if len(pids) < n:
new_points, new_losses = self.learner.ask(n - len(pids))
loss_improvements += new_losses
return points, loss_improvements
for point in new_points:
pid = self._next_id()
self._id_to_point[pid] = point
pids.append(pid)
return pids, loss_improvements

def overhead(self):
"""Overhead of using Adaptive and the executor in percent.
Expand All @@ -197,21 +211,22 @@ def overhead(self):

def _process_futures(self, done_futs):
for fut in done_futs:
x = self.pending_points.pop(fut)
pid = self._pending_tasks.pop(fut)
try:
y = fut.result()
t = time.time() - fut.start_time # total execution time
except Exception as e:
self.tracebacks[x] = traceback.format_exc()
self.to_retry[x] = self.to_retry.get(x, 0) + 1
if self.to_retry[x] > self.retries:
self.to_retry.pop(x)
self._tracebacks[pid] = traceback.format_exc()
self._to_retry[pid] = self._to_retry.get(pid, 0) + 1
if self._to_retry[pid] > self.retries:
self._to_retry.pop(pid)
if self.raise_if_retries_exceeded:
self._do_raise(e, x)
self._do_raise(e, pid)
else:
self._elapsed_function_time += t / self._get_max_tasks()
self.to_retry.pop(x, None)
self.tracebacks.pop(x, None)
self._to_retry.pop(pid, None)
self._tracebacks.pop(pid, None)
x = self._id_to_point.pop(pid)
if self.do_log:
self.log.append(("tell", x, y))
self.learner.tell(x, y)
Expand All @@ -220,28 +235,29 @@ def _get_futures(self):
# Launch tasks to replace the ones that completed
# on the last iteration, making sure to fill workers
# that have started since the last iteration.
n_new_tasks = max(0, self._get_max_tasks() - len(self.pending_points))
n_new_tasks = max(0, self._get_max_tasks() - len(self._pending_tasks))

if self.do_log:
self.log.append(("ask", n_new_tasks))

points, _ = self._ask(n_new_tasks)
pids, _ = self._ask(n_new_tasks)

for x in points:
for pid in pids:
start_time = time.time() # so we can measure execution time
fut = self._submit(x)
point = self._id_to_point[pid]
fut = self._submit(point)
fut.start_time = start_time
self.pending_points[fut] = x
self._pending_tasks[fut] = pid

# Collect and results and add them to the learner
futures = list(self.pending_points.keys())
futures = list(self._pending_tasks.keys())
return futures

def _remove_unfinished(self):
# remove points with 'None' values from the learner
self.learner.remove_unfinished()
# cancel any outstanding tasks
remaining = list(self.pending_points.keys())
remaining = list(self._pending_tasks.keys())
for fut in remaining:
fut.cancel()
return remaining
Expand All @@ -260,7 +276,7 @@ def _cleanup(self):
@property
def failed(self):
    """Set of ids of points that failed ``runner.retries`` times."""
    return {pid for pid in self._tracebacks if pid not in self._to_retry}

@abc.abstractmethod
def elapsed_time(self):
Expand All @@ -276,6 +292,20 @@ def _submit(self, x):
"""Is called in `_get_futures`."""
pass

@property
def tracebacks(self):
    """The failed points and their tracebacks, as ``(point, tb)`` tuples."""
    points = map(self._id_to_point.__getitem__, self._tracebacks)
    return list(zip(points, self._tracebacks.values()))

@property
def to_retry(self):
    """``(point, n_fails)`` tuples for points that failed and may be retried."""
    return [
        (self._id_to_point[point_id], n_fails)
        for point_id, n_fails in self._to_retry.items()
    ]

@property
def pending_points(self):
    """``(future, point)`` tuples for the evaluations currently in flight."""
    return [
        (future, self._id_to_point[point_id])
        for future, point_id in self._pending_tasks.items()
    ]


class BlockingRunner(BaseRunner):
"""Run a learner synchronously in an executor.
Expand Down Expand Up @@ -315,14 +345,14 @@ class BlockingRunner(BaseRunner):
log : list or None
Record of the method calls made to the learner, in the format
``(method_name, *args)``.
to_retry : dict
Mapping of ``{point: n_fails, ...}``. When a point has failed
to_retry : list of tuples
List of ``(point, n_fails)``. When a point has failed
``runner.retries`` times it is removed but will be present
in ``runner.tracebacks``.
tracebacks : dict
A mapping of point to the traceback if that point failed.
pending_points : dict
A mapping of `~concurrent.futures.Future`\to points.
tracebacks : list of tuples
List of ``(point, tb)`` for points that failed.
pending_points : list of tuples
A list of tuples with ``(concurrent.futures.Future, point)``.
Methods
-------
Expand Down Expand Up @@ -438,14 +468,14 @@ class AsyncRunner(BaseRunner):
log : list or None
Record of the method calls made to the learner, in the format
``(method_name, *args)``.
to_retry : dict
Mapping of ``{point: n_fails, ...}``. When a point has failed
to_retry : list of tuples
List of ``(point, n_fails)``. When a point has failed
``runner.retries`` times it is removed but will be present
in ``runner.tracebacks``.
tracebacks : dict
A mapping of point to the traceback if that point failed.
pending_points : dict
A mapping of `~concurrent.futures.Future`\s to points.
tracebacks : list of tuples
List of ``(point, tb)`` for points that failed.
pending_points : list of tuples
A list of tuples with ``(concurrent.futures.Future, point)``.
Methods
-------
Expand Down
6 changes: 3 additions & 3 deletions docs/source/tutorial/tutorial.advanced-topics.rst
Original file line number Diff line number Diff line change
Expand Up @@ -297,12 +297,12 @@ raise the exception with the stack trace:
runner.task.result()


You can also check ``runner.tracebacks`` which is a mapping from
point to traceback.
You can also check ``runner.tracebacks`` which is a list of tuples with
(point, traceback).

.. jupyter-execute::

for point, tb in runner.tracebacks.items():
for point, tb in runner.tracebacks:
print(f'point: {point}:\n {tb}')

Logging runners
Expand Down

0 comments on commit 5754320

Please sign in to comment.