diff --git a/adaptive/runner.py b/adaptive/runner.py
index 9a8a9b1a5..73fd9ed94 100644
--- a/adaptive/runner.py
+++ b/adaptive/runner.py
@@ -1,7 +1,9 @@
 import abc
 import asyncio
 import concurrent.futures as concurrent
+import functools
 import inspect
+import itertools
 import pickle
 import sys
 import time
@@ -91,14 +93,14 @@ class BaseRunner(metaclass=abc.ABCMeta):
     log : list or None
         Record of the method calls made to the learner, in the
         format ``(method_name, *args)``.
-    to_retry : dict
-        Mapping of ``{point: n_fails, ...}``. When a point has failed
+    to_retry : list of tuples
+        List of ``(point, n_fails)``. When a point has failed
         ``runner.retries`` times it is removed but will be present in
         ``runner.tracebacks``.
-    tracebacks : dict
-        A mapping of point to the traceback if that point failed.
-    pending_points : dict
-        A mapping of `~concurrent.futures.Future`\s to points.
+    tracebacks : list of tuples
+        List of ``(point, tb)`` for points that failed.
+    pending_points : list of tuples
+        A list of tuples with ``(concurrent.futures.Future, point)``.
 
     Methods
     -------
@@ -126,7 +128,7 @@ def __init__(
 
         self._max_tasks = ntasks
 
-        self.pending_points = {}
+        self._pending_tasks = {}  # mapping from concurrent.futures.Future → point id
 
         # if we instantiate our own executor, then we are also responsible
        # for calling 'shutdown'
@@ -143,14 +145,20 @@ def __init__(
         # Error handling attributes
         self.retries = retries
         self.raise_if_retries_exceeded = raise_if_retries_exceeded
-        self.to_retry = {}
-        self.tracebacks = {}
+        self._to_retry = {}
+        self._tracebacks = {}
+
+        self._id_to_point = {}
+        self._next_id = functools.partial(
+            next, itertools.count()
+        )  # some unique id to be associated with each point
 
     def _get_max_tasks(self):
         return self._max_tasks or _get_ncores(self.executor)
 
-    def _do_raise(self, e, x):
-        tb = self.tracebacks[x]
+    def _do_raise(self, e, i):
+        tb = self._tracebacks[i]
+        x = self._id_to_point[i]
         raise RuntimeError(
             "An error occured while evaluating "
             f'"learner.function({x})". '
@@ -162,15 +170,21 @@ def do_log(self):
         return self.log is not None
 
     def _ask(self, n):
-        points = [
-            p for p in self.to_retry.keys() if p not in self.pending_points.values()
-        ][:n]
-        loss_improvements = len(points) * [float("inf")]
-        if len(points) < n:
-            new_points, new_losses = self.learner.ask(n - len(points))
-            points += new_points
+        pending_ids = self._pending_tasks.values()
+        # using generator here because we only need until `n`
+        pids_gen = (pid for pid in self._to_retry.keys() if pid not in pending_ids)
+        pids = list(itertools.islice(pids_gen, n))
+
+        loss_improvements = len(pids) * [float("inf")]
+
+        if len(pids) < n:
+            new_points, new_losses = self.learner.ask(n - len(pids))
             loss_improvements += new_losses
-        return points, loss_improvements
+            for point in new_points:
+                pid = self._next_id()
+                self._id_to_point[pid] = point
+                pids.append(pid)
+        return pids, loss_improvements
 
     def overhead(self):
         """Overhead of using Adaptive and the executor in percent.
@@ -197,21 +211,22 @@ def overhead(self):
 
     def _process_futures(self, done_futs):
         for fut in done_futs:
-            x = self.pending_points.pop(fut)
+            pid = self._pending_tasks.pop(fut)
             try:
                 y = fut.result()
                 t = time.time() - fut.start_time  # total execution time
             except Exception as e:
-                self.tracebacks[x] = traceback.format_exc()
-                self.to_retry[x] = self.to_retry.get(x, 0) + 1
-                if self.to_retry[x] > self.retries:
-                    self.to_retry.pop(x)
+                self._tracebacks[pid] = traceback.format_exc()
+                self._to_retry[pid] = self._to_retry.get(pid, 0) + 1
+                if self._to_retry[pid] > self.retries:
+                    self._to_retry.pop(pid)
                     if self.raise_if_retries_exceeded:
-                        self._do_raise(e, x)
+                        self._do_raise(e, pid)
             else:
                 self._elapsed_function_time += t / self._get_max_tasks()
-                self.to_retry.pop(x, None)
-                self.tracebacks.pop(x, None)
+                self._to_retry.pop(pid, None)
+                self._tracebacks.pop(pid, None)
+                x = self._id_to_point.pop(pid)
                 if self.do_log:
                     self.log.append(("tell", x, y))
                 self.learner.tell(x, y)
@@ -220,28 +235,29 @@ def _get_futures(self):
         # Launch tasks to replace the ones that completed
         # on the last iteration, making sure to fill workers
         # that have started since the last iteration.
-        n_new_tasks = max(0, self._get_max_tasks() - len(self.pending_points))
+        n_new_tasks = max(0, self._get_max_tasks() - len(self._pending_tasks))
 
         if self.do_log:
             self.log.append(("ask", n_new_tasks))
 
-        points, _ = self._ask(n_new_tasks)
+        pids, _ = self._ask(n_new_tasks)
 
-        for x in points:
+        for pid in pids:
             start_time = time.time()  # so we can measure execution time
-            fut = self._submit(x)
+            point = self._id_to_point[pid]
+            fut = self._submit(point)
             fut.start_time = start_time
-            self.pending_points[fut] = x
+            self._pending_tasks[fut] = pid
 
         # Collect and results and add them to the learner
-        futures = list(self.pending_points.keys())
+        futures = list(self._pending_tasks.keys())
         return futures
 
     def _remove_unfinished(self):
         # remove points with 'None' values from the learner
         self.learner.remove_unfinished()
         # cancel any outstanding tasks
-        remaining = list(self.pending_points.keys())
+        remaining = list(self._pending_tasks.keys())
         for fut in remaining:
             fut.cancel()
         return remaining
@@ -260,7 +276,7 @@ def _cleanup(self):
     @property
     def failed(self):
         """Set of points that failed ``runner.retries`` times."""
-        return set(self.tracebacks) - set(self.to_retry)
+        return set(self._tracebacks) - set(self._to_retry)
 
     @abc.abstractmethod
     def elapsed_time(self):
@@ -276,6 +292,20 @@ def _submit(self, x):
         """Is called in `_get_futures`."""
         pass
+
+    @property
+    def tracebacks(self):
+        return [(self._id_to_point[pid], tb) for pid, tb in self._tracebacks.items()]
+
+    @property
+    def to_retry(self):
+        return [(self._id_to_point[pid], n) for pid, n in self._to_retry.items()]
+
+    @property
+    def pending_points(self):
+        return [
+            (fut, self._id_to_point[pid]) for fut, pid in self._pending_tasks.items()
+        ]
 
 
 class BlockingRunner(BaseRunner):
     """Run a learner synchronously in an executor.
@@ -315,14 +345,14 @@ class BlockingRunner(BaseRunner):
     log : list or None
         Record of the method calls made to the learner, in the
         format ``(method_name, *args)``.
-    to_retry : dict
-        Mapping of ``{point: n_fails, ...}``. When a point has failed
+    to_retry : list of tuples
+        List of ``(point, n_fails)``. When a point has failed
         ``runner.retries`` times it is removed but will be present in
         ``runner.tracebacks``.
-    tracebacks : dict
-        A mapping of point to the traceback if that point failed.
-    pending_points : dict
-        A mapping of `~concurrent.futures.Future`\s to points.
+    tracebacks : list of tuples
+        List of ``(point, tb)`` for points that failed.
+    pending_points : list of tuples
+        A list of tuples with ``(concurrent.futures.Future, point)``.
 
     Methods
     -------
@@ -438,14 +468,14 @@ class AsyncRunner(BaseRunner):
     log : list or None
         Record of the method calls made to the learner, in the
         format ``(method_name, *args)``.
-    to_retry : dict
-        Mapping of ``{point: n_fails, ...}``. When a point has failed
+    to_retry : list of tuples
+        List of ``(point, n_fails)``. When a point has failed
         ``runner.retries`` times it is removed but will be present in
         ``runner.tracebacks``.
-    tracebacks : dict
-        A mapping of point to the traceback if that point failed.
-    pending_points : dict
-        A mapping of `~concurrent.futures.Future`\s to points.
+    tracebacks : list of tuples
+        List of ``(point, tb)`` for points that failed.
+    pending_points : list of tuples
+        A list of tuples with ``(concurrent.futures.Future, point)``.
 
     Methods
     -------
diff --git a/docs/source/tutorial/tutorial.advanced-topics.rst b/docs/source/tutorial/tutorial.advanced-topics.rst
index 84487a9cb..1156979c8 100644
--- a/docs/source/tutorial/tutorial.advanced-topics.rst
+++ b/docs/source/tutorial/tutorial.advanced-topics.rst
@@ -297,12 +297,12 @@ raise the exception with the stack trace:
 
     runner.task.result()
 
-You can also check ``runner.tracebacks`` which is a mapping from
-point → traceback.
+You can also check ``runner.tracebacks``, which is a list of tuples with
+(point, traceback).
 
 .. jupyter-execute::
 
-    for point, tb in runner.tracebacks.items():
+    for point, tb in runner.tracebacks:
         print(f'point: {point}:\n {tb}')
 
 Logging runners
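
Note (not part of the patch): the hunks above replace the dict-valued ``to_retry``, ``tracebacks`` and ``pending_points`` attributes with read-only properties that return lists of tuples, keyed internally by a per-point id. Below is a minimal, hypothetical usage sketch of the resulting API; the learner, goal, and deliberately failing function are illustrative assumptions and do not come from this diff.

    # Illustrative sketch only: the learner, goal, and failing function are
    # assumptions for demonstration; the list-of-tuples attributes are the
    # ones introduced by the patch above.
    import adaptive


    def f(x):
        if x == -1:  # make one point fail deterministically
            raise ValueError("synthetic failure")
        return x ** 2


    if __name__ == "__main__":
        learner = adaptive.Learner1D(f, bounds=(-1, 1))
        runner = adaptive.BlockingRunner(
            learner,
            goal=lambda l: l.npoints >= 30,
            retries=2,
            raise_if_retries_exceeded=False,
        )

        # Dict-style access such as ``runner.tracebacks.items()`` no longer
        # applies; iterate over the tuples directly.
        for point, n_fails in runner.to_retry:
            print(f"point {point} is still scheduled for retry after {n_fails} failures")
        for point, tb in runner.tracebacks:
            print(f"point {point} failed:\n{tb}")
        for future, point in runner.pending_points:
            print(f"point {point} is still pending in {future}")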