diff --git a/adaptive/runner.py b/adaptive/runner.py index 9a8a9b1a5..2a47c7b75 100644 --- a/adaptive/runner.py +++ b/adaptive/runner.py @@ -54,6 +54,12 @@ ) +def _key_by_value(dct, value): + for k, v in dct.items(): + if v == value: + return k + + class BaseRunner(metaclass=abc.ABCMeta): r"""Base class for runners that use `concurrent.futures.Executors`. @@ -146,11 +152,16 @@ def __init__( self.to_retry = {} self.tracebacks = {} + # Keeping track of index -> point + self.index_to_point = {} + self._i = 0 # some unique index to be associated with each point + def _get_max_tasks(self): return self._max_tasks or _get_ncores(self.executor) - def _do_raise(self, e, x): - tb = self.tracebacks[x] + def _do_raise(self, e, i): + tb = self.tracebacks[i] + x = self.index_to_point[i] raise RuntimeError( "An error occured while evaluating " f'"learner.function({x})". ' @@ -162,9 +173,14 @@ def do_log(self): return self.log is not None def _ask(self, n): - points = [ - p for p in self.to_retry.keys() if p not in self.pending_points.values() - ][:n] + points = [] + for i, index in enumerate(self.to_retry.keys()): + if i == n: + break + point = self.index_to_point[index] + if point not in self.pending_points.values(): + points.append(point) + loss_improvements = len(points) * [float("inf")] if len(points) < n: new_points, new_losses = self.learner.ask(n - len(points)) @@ -198,20 +214,22 @@ def overhead(self): def _process_futures(self, done_futs): for fut in done_futs: x = self.pending_points.pop(fut) + i = _key_by_value(self.index_to_point, x) # O(N) try: y = fut.result() t = time.time() - fut.start_time # total execution time except Exception as e: - self.tracebacks[x] = traceback.format_exc() - self.to_retry[x] = self.to_retry.get(x, 0) + 1 - if self.to_retry[x] > self.retries: - self.to_retry.pop(x) + self.tracebacks[i] = traceback.format_exc() + self.to_retry[i] = self.to_retry.get(i, 0) + 1 + if self.to_retry[i] > self.retries: + self.to_retry.pop(i) if self.raise_if_retries_exceeded: - self._do_raise(e, x) + self._do_raise(e, i) else: self._elapsed_function_time += t / self._get_max_tasks() - self.to_retry.pop(x, None) - self.tracebacks.pop(x, None) + self.to_retry.pop(i, None) + self.tracebacks.pop(i, None) + self.index_to_point.pop(i) if self.do_log: self.log.append(("tell", x, y)) self.learner.tell(x, y) @@ -232,6 +250,12 @@ def _get_futures(self): fut = self._submit(x) fut.start_time = start_time self.pending_points[fut] = x + i = _key_by_value(self.index_to_point, x) # O(N) + if i is None: + # `x` is not a value in `self.index_to_point` + self._i += 1 + i = self._i + self.index_to_point[i] = x # Collect and results and add them to the learner futures = list(self.pending_points.keys())