diff --git a/adaptive/runner.py b/adaptive/runner.py index 2a47c7b75..2e634fcce 100644 --- a/adaptive/runner.py +++ b/adaptive/runner.py @@ -97,12 +97,12 @@ class BaseRunner(metaclass=abc.ABCMeta): log : list or None Record of the method calls made to the learner, in the format ``(method_name, *args)``. - to_retry : dict - Mapping of ``{point: n_fails, ...}``. When a point has failed + to_retry : list of tuples + List of ``(point, n_fails)``. When a point has failed ``runner.retries`` times it is removed but will be present in ``runner.tracebacks``. - tracebacks : dict - A mapping of point to the traceback if that point failed. + tracebacks : list of tuples + List of of ``(point, tb)`` for points that failed. pending_points : dict A mapping of `~concurrent.futures.Future`\s to points. @@ -149,19 +149,19 @@ def __init__( # Error handling attributes self.retries = retries self.raise_if_retries_exceeded = raise_if_retries_exceeded - self.to_retry = {} - self.tracebacks = {} + self._to_retry = {} + self._tracebacks = {} # Keeping track of index -> point - self.index_to_point = {} + self._index_to_point = {} self._i = 0 # some unique index to be associated with each point def _get_max_tasks(self): return self._max_tasks or _get_ncores(self.executor) def _do_raise(self, e, i): - tb = self.tracebacks[i] - x = self.index_to_point[i] + tb = self._tracebacks[i] + x = self._index_to_point[i] raise RuntimeError( "An error occured while evaluating " f'"learner.function({x})". ' @@ -174,10 +174,10 @@ def do_log(self): def _ask(self, n): points = [] - for i, index in enumerate(self.to_retry.keys()): + for i, index in enumerate(self._to_retry.keys()): if i == n: break - point = self.index_to_point[index] + point = self._index_to_point[index] if point not in self.pending_points.values(): points.append(point) @@ -214,22 +214,22 @@ def overhead(self): def _process_futures(self, done_futs): for fut in done_futs: x = self.pending_points.pop(fut) - i = _key_by_value(self.index_to_point, x) # O(N) + i = _key_by_value(self._index_to_point, x) # O(N) try: y = fut.result() t = time.time() - fut.start_time # total execution time except Exception as e: - self.tracebacks[i] = traceback.format_exc() - self.to_retry[i] = self.to_retry.get(i, 0) + 1 - if self.to_retry[i] > self.retries: - self.to_retry.pop(i) + self._tracebacks[i] = traceback.format_exc() + self._to_retry[i] = self._to_retry.get(i, 0) + 1 + if self._to_retry[i] > self.retries: + self._to_retry.pop(i) if self.raise_if_retries_exceeded: self._do_raise(e, i) else: self._elapsed_function_time += t / self._get_max_tasks() - self.to_retry.pop(i, None) - self.tracebacks.pop(i, None) - self.index_to_point.pop(i) + self._to_retry.pop(i, None) + self._tracebacks.pop(i, None) + self._index_to_point.pop(i) if self.do_log: self.log.append(("tell", x, y)) self.learner.tell(x, y) @@ -250,12 +250,12 @@ def _get_futures(self): fut = self._submit(x) fut.start_time = start_time self.pending_points[fut] = x - i = _key_by_value(self.index_to_point, x) # O(N) + i = _key_by_value(self._index_to_point, x) # O(N) if i is None: - # `x` is not a value in `self.index_to_point` + # `x` is not a value in `self._index_to_point` self._i += 1 i = self._i - self.index_to_point[i] = x + self._index_to_point[i] = x # Collect and results and add them to the learner futures = list(self.pending_points.keys()) @@ -284,7 +284,7 @@ def _cleanup(self): @property def failed(self): """Set of points that failed ``runner.retries`` times.""" - return set(self.tracebacks) - set(self.to_retry) + return set(self._tracebacks) - set(self._to_retry) @abc.abstractmethod def elapsed_time(self): @@ -300,6 +300,14 @@ def _submit(self, x): """Is called in `_get_futures`.""" pass + @property + def tracebacks(self): + return [(self._index_to_point[i], tb) for i, tb in self._tracebacks.items()] + + @property + def retries(self): + return [(self._index_to_point[i], n) for i, n in self._tracebacks.items()] + class BlockingRunner(BaseRunner): """Run a learner synchronously in an executor.