Skip to content

Commit

Permalink
make the Runner work with unhashable points
Browse files Browse the repository at this point in the history
  • Loading branch information
basnijholt committed Apr 16, 2020
1 parent 3d9397d commit 69b97f5
Showing 1 changed file with 36 additions and 12 deletions.
48 changes: 36 additions & 12 deletions adaptive/runner.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,6 +54,12 @@
)


def _key_by_value(dct, value):
for k, v in dct.items():
if v == value:
return k


class BaseRunner(metaclass=abc.ABCMeta):
r"""Base class for runners that use `concurrent.futures.Executors`.
Expand Down Expand Up @@ -146,11 +152,16 @@ def __init__(
self.to_retry = {}
self.tracebacks = {}

# Keeping track of index -> point
self.index_to_point = {}
self._i = 0 # some unique index to be associated with each point

def _get_max_tasks(self):
return self._max_tasks or _get_ncores(self.executor)

def _do_raise(self, e, x):
tb = self.tracebacks[x]
def _do_raise(self, e, i):
tb = self.tracebacks[i]
x = self.index_to_point[i]
raise RuntimeError(
"An error occured while evaluating "
f'"learner.function({x})". '
Expand All @@ -162,9 +173,14 @@ def do_log(self):
return self.log is not None

def _ask(self, n):
points = [
p for p in self.to_retry.keys() if p not in self.pending_points.values()
][:n]
points = []
for i, index in enumerate(self.to_retry.keys()):
if i == n:
break
point = self.index_to_point[index]
if point not in self.pending_points.values():
points.append(point)

loss_improvements = len(points) * [float("inf")]
if len(points) < n:
new_points, new_losses = self.learner.ask(n - len(points))
Expand Down Expand Up @@ -198,20 +214,22 @@ def overhead(self):
def _process_futures(self, done_futs):
for fut in done_futs:
x = self.pending_points.pop(fut)
i = _key_by_value(self.index_to_point, x) # O(N)
try:
y = fut.result()
t = time.time() - fut.start_time # total execution time
except Exception as e:
self.tracebacks[x] = traceback.format_exc()
self.to_retry[x] = self.to_retry.get(x, 0) + 1
if self.to_retry[x] > self.retries:
self.to_retry.pop(x)
self.tracebacks[i] = traceback.format_exc()
self.to_retry[i] = self.to_retry.get(i, 0) + 1
if self.to_retry[i] > self.retries:
self.to_retry.pop(i)
if self.raise_if_retries_exceeded:
self._do_raise(e, x)
self._do_raise(e, i)
else:
self._elapsed_function_time += t / self._get_max_tasks()
self.to_retry.pop(x, None)
self.tracebacks.pop(x, None)
self.to_retry.pop(i, None)
self.tracebacks.pop(i, None)
self.index_to_point.pop(i)
if self.do_log:
self.log.append(("tell", x, y))
self.learner.tell(x, y)
Expand All @@ -232,6 +250,12 @@ def _get_futures(self):
fut = self._submit(x)
fut.start_time = start_time
self.pending_points[fut] = x
i = _key_by_value(self.index_to_point, x) # O(N)
if i is None:
# `x` is not a value in `self.index_to_point`
self._i += 1
i = self._i
self.index_to_point[i] = x

# Collect and results and add them to the learner
futures = list(self.pending_points.keys())
Expand Down

0 comments on commit 69b97f5

Please sign in to comment.