diff --git a/adaptive/runner.py b/adaptive/runner.py index ccf5ac4cd..fa96096ca 100644 --- a/adaptive/runner.py +++ b/adaptive/runner.py @@ -175,10 +175,10 @@ def do_log(self): def _ask(self, n): points = [] - for i, _id in enumerate(self._to_retry.keys()): + for i, pid in enumerate(self._to_retry.keys()): if i == n: break - point = self._id_to_point[_id] + point = self._id_to_point[pid] if point not in self.pending_points.values(): points.append(point) @@ -187,6 +187,8 @@ def _ask(self, n): new_points, new_losses = self.learner.ask(n - len(points)) points += new_points loss_improvements += new_losses + for p in new_points: + self._id_to_point[self._next_id()] = p return points, loss_improvements def overhead(self): @@ -251,11 +253,6 @@ def _get_futures(self): fut = self._submit(x) fut.start_time = start_time self.pending_points[fut] = x - try: - _id = _key_by_value(self._id_to_point, x) # O(N) - except StopIteration: # `x` is not a value in `self._id_to_point` - _id = self._next_id() - self._id_to_point[_id] = x # Collect and results and add them to the learner futures = list(self.pending_points.keys())