Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Recreate actor instances upon worker failure #4287

Draft
wants to merge 19 commits into
base: main
Choose a base branch
from
Draft
114 changes: 92 additions & 22 deletions distributed/actor.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,10 @@
import asyncio
import functools
from inspect import iscoroutinefunction
import random
import threading
from queue import Queue
import time
from queue import Queue, Empty

from .client import Future, default_client
from .protocol import to_serialize
Expand Down Expand Up @@ -56,14 +58,19 @@ def __init__(self, cls, address, key, worker=None):
self.key = key
self._future = None
if worker:
# made by a worker
self._worker = worker
self._client = None
assert self.key in self._worker.actors
assert self._address == self._worker.address
else:
try:
# instance on a worker, but not made by worker
self._worker = get_worker()
except ValueError:
self._worker = None
try:
# claim remote original actor
self._client = default_client()
self._future = Future(key)
except ValueError:
Expand All @@ -73,7 +80,7 @@ def __repr__(self):
return "<Actor: %s, key=%s>" % (self._cls.__name__, self.key)

def __reduce__(self):
return (Actor, (self._cls, self._address, self.key))
return Actor, (self._cls, self._address, self.key)

@property
def _io_loop(self):
Expand All @@ -89,6 +96,18 @@ def _scheduler_rpc(self):
else:
return self._client.scheduler

def set_address(self):
    """Re-discover the address of the worker currently hosting this actor.

    Asks the scheduler (via ``find_actor``) where the actor now lives —
    it may have been re-created on a new worker after a failure — and
    updates ``self._address``.  On the client side, also refreshes the
    ``Future`` tracking the actor's key.
    """
    # find_actor returns a one-element list of worker addresses.
    (self._address,) = self._sync(
        self._scheduler_rpc.find_actor, actor_key=self.key
    )
    if self._client:
        if self._future:
            # Block until the (possibly re-created) actor is available.
            self._future.result()
        else:
            self._future = Future(self.key)
    # NOTE(review): a worker holding only a remote reference needs no
    # extra bookkeeping here — the original ``elif self._worker: pass``
    # branch was a no-op and has been removed.

@property
def _worker_rpc(self):
if self._worker:
Expand Down Expand Up @@ -152,33 +171,47 @@ def __getattr__(self, key):
def func(*args, **kwargs):
async def run_actor_function_on_worker():
try:
result = await self._worker_rpc.actor_execute(
function=key,
actor=self.key,
args=[to_serialize(arg) for arg in args],
kwargs={k: to_serialize(v) for k, v in kwargs.items()},
result = await asyncio.wait_for(
self._worker_rpc.actor_execute(
function=key,
actor=self.key,
args=[to_serialize(arg) for arg in args],
kwargs={k: to_serialize(v) for k, v in kwargs.items()},
),
timeout=2,
)
except OSError:
if self._future:
await self._future
else:
raise OSError("Unable to contact Actor's worker")
return result["result"]
except Exception as e:
# assertion error is a low-level comm validation error
result = {"exception": e}

return result

if self._asynchronous:
return asyncio.ensure_future(run_actor_function_on_worker())

async def unwrap():
result = await run_actor_function_on_worker()
if "result" in result:
return result["result"]
raise result["exception"]

return asyncio.ensure_future(unwrap())
else:
# TODO: this mechanism is error prone
# we should endeavor to make dask's standard code work here
q = Queue()

async def wait_then_add_to_queue():
x = await run_actor_function_on_worker()
q.put(x)
try:
x = await run_actor_function_on_worker()
q.put(x)
except Exception as e:
q.put({"exception": e})

self._io_loop.add_callback(wait_then_add_to_queue)

return ActorFuture(q, self._io_loop)
return ActorFuture(
q, self._io_loop, actor=self, defs=(key, args, kwargs)
)

return func

Expand All @@ -188,7 +221,10 @@ async def get_actor_attribute_from_worker():
x = await self._worker_rpc.actor_attribute(
attribute=key, actor=self.key
)
return x["result"]
if "result" in x:
return x["result"]
else:
raise x["exception"]

return self._sync(get_actor_attribute_from_worker)

Expand Down Expand Up @@ -227,21 +263,55 @@ class ActorFuture:
Actor
"""

def __init__(self, q, io_loop, result=None):
def __init__(self, q, io_loop, result=None, actor=None, defs=None):
self.q = q
self.io_loop = io_loop
if result:
self._cached_result = result
else:
self.actor = actor
self.defs = defs

def __await__(self):
    # NOTE(review): this returns the plain result of ``self.result()``
    # rather than an iterator; the awaitable protocol expects
    # ``__await__`` to return an iterator (e.g. by delegating to a
    # coroutine/generator) — confirm ``await actor_future`` actually
    # works through this path.
    return self.result()

def result(self, timeout=None):
def result(self, timeout=2.5, retries=2):
try:
if isinstance(self._cached_result, Exception):
raise self._cached_result
return self._cached_result
except AttributeError:
self._cached_result = self.q.get(timeout=timeout)
return self._cached_result
pass
try:
out = self.q.get(timeout=timeout)
if "result" in out:
self._cached_result = out["result"]
else:
ex = out["exception"]
if retries > 0:
self._reset()
return self.result(retries=retries - 1)
self._cached_result = ex
except Empty:
if retries > 0:
self._reset()
return self.result(retries=retries - 1)
self._cached_result = TimeoutError()
self.actor = None
self.defs = None
return self.result()

def _reset(self):
    """Retry helper: re-locate the actor and re-issue the original call.

    Called from ``result()`` when a call timed out or errored, typically
    because the hosting worker died.
    """
    # Randomized 0.1–1.1 s back-off before retrying, giving the
    # scheduler time to re-home the actor and spreading out retries
    # from many futures at once.  NOTE(review): this blocks the calling
    # thread — confirm it is never invoked on the event-loop thread.
    time.sleep(random.random() + 0.1)
    # Ask the scheduler where the actor now lives (may re-create it).
    self.actor.set_address()
    attr, args, kwargs = self.defs
    # NOTE(review): ``defs`` is only populated for method calls made via
    # ``Actor.__getattr__``; confirm the attribute branch below is ever
    # reached with a non-None ``defs``.
    if args is not None:
        # method call — re-issue against the relocated actor
        ac2 = getattr(self.actor, attr)(*args, **kwargs)
    else:
        # attribute access — re-fetch the attribute
        ac2 = getattr(self.actor, attr)
    # Adopt the fresh ActorFuture's state (queue, loop, actor, defs) so
    # the caller's existing future object keeps working transparently.
    self.__dict__.update(ac2.__dict__)

def __repr__(self):
return "<ActorFuture>"
3 changes: 3 additions & 0 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -1804,6 +1804,9 @@ def map(

return [futures[stringify(k)] for k in keys]

def find_actor(self, ac):
    """Ask the scheduler which worker(s) currently host actor ``ac``."""
    actor_key = ac.key
    return self.sync(self.scheduler.find_actor, actor_key=actor_key)

async def _gather(self, futures, errors="raise", direct=None, local_worker=None):
unpacked, future_set = unpack_remotedata(futures, byte_keys=True)
keys = [stringify(future.key) for future in future_set]
Expand Down
20 changes: 19 additions & 1 deletion distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -262,7 +262,7 @@ class WorkerState:
.. attribute:: processing: {TaskState: cost}

A dictionary of tasks that have been submitted to this worker.
Each task state is asssociated with the expected cost in seconds
Each task state is associated with the expected cost in seconds
of running that task, summing both the task's expected computation
time and the expected communication time of its result.

Expand Down Expand Up @@ -1839,6 +1839,7 @@ def __init__(
"ncores": self.get_ncores,
"has_what": self.get_has_what,
"who_has": self.get_who_has,
"find_actor": self.find_actor,
"processing": self.get_processing,
"call_stack": self.get_call_stack,
"profile": self.get_profile,
Expand Down Expand Up @@ -4368,6 +4369,23 @@ def get_processing(self, comm=None, workers=None):
w: [ts._key for ts in ws._processing] for w, ws in self.workers.items()
}

def find_actor(self, _, actor_key=None):
    """Return the address(es) of worker(s) currently holding the actor.

    Parameters
    ----------
    _ : comm (unused)
    actor_key : str
        Key of the actor task to locate.

    Returns
    -------
    list of worker addresses hosting the actor; ``[]`` when the key is
    unknown to the scheduler.  If the task is known but no worker holds
    a live instance, the task is re-dispatched to a worker (recreating
    the actor) and that worker's address is returned.
    """
    if actor_key not in self.tasks:
        # Unknown key: nothing to locate.  (Bug fix: the original
        # asserted on self.tasks[actor_key] BEFORE this membership
        # check, raising KeyError instead of returning [].)
        return []
    ts = self.tasks[actor_key]
    assert ts.actor
    out = [
        ws.address
        for ws in ts.who_has
        if any(a.key == actor_key for a in ws.actors)
    ]
    if not out:
        # No live instance anywhere: pick a worker (preferring ones
        # that already have the task's data) and re-dispatch the task
        # so the actor is recreated.  random.choice needs a sequence,
        # so self.workers (a dict) is materialized as a list of
        # addresses (bug fix: choice() on a dict raises).
        candidates = [ws.address for ws in ts.who_has] or list(self.workers)
        worker = random.choice(candidates)
        self.send_task_to_worker(worker, actor_key)
        out = [worker]
    return out

def get_who_has(self, comm=None, keys=None):
ws: WorkerState
ts: TaskState
Expand Down
115 changes: 115 additions & 0 deletions distributed/tests/test_actor.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
import asyncio
import operator
import sys
from time import sleep

import pytest
Expand All @@ -9,13 +10,15 @@
from distributed.utils_test import cluster, gen_cluster
from distributed.utils_test import client, cluster_fixture, loop # noqa: F401
from distributed.metrics import time
from distributed.worker import get_worker


class Counter:
n = 0

def __init__(self):
self.n = 0
self.should_kill = False

def increment(self):
self.n += 1
Expand All @@ -29,6 +32,22 @@ def add(self, x):
self.n += x
return self.n

def set_kill(self):
self.should_kill = True
return True

def kill_now(self):
if self.should_kill:
sys.exit()
return True

def kill_soon(self):
if self.should_kill:
loop = get_worker().loop
loop.add_callback(sys.exit)
return False
return True


class UsesCounter:
# An actor whose method argument is another actor
Expand All @@ -40,6 +59,20 @@ async def ado_inc(self, ac):
return await ac.ainc()


class UsesCounterInit:
    # An actor that is handed another actor at construction time and
    # keeps the reference for later method calls.

    def __init__(self, ac):
        self.ac = ac

    def do_inc(self):
        # Synchronous path: issue the remote call, then block on it.
        fut = self.ac.increment()
        return fut.result()

    async def ado_inc(self):
        # Asynchronous path: await the remote increment directly.
        return await self.ac.ainc()


class List:
L = []

Expand Down Expand Up @@ -581,3 +614,85 @@ async def test_async_deadlock(client, s, a, b):
ac2 = await client.submit(UsesCounter, actor=True, workers=[ac._address])

assert (await ac2.ado_inc(ac)) == 1


def test_exception():
    """Exceptions raised inside an actor surface to the caller."""

    class MyException(Exception):
        pass

    class Broken:
        def method(self):
            raise MyException

        @property
        def prop(self):
            raise MyException

    with cluster(nworkers=2) as (cl, w):
        c = Client(cl["address"])
        actor = c.submit(Broken, actor=True).result()

        # A failing method call raises when its future is resolved.
        fut = actor.method()
        with pytest.raises(MyException):
            fut.result()

        # A failing property raises at attribute access time.
        with pytest.raises(MyException):
            actor.prop


def test_actor_retire():
    """Actors survive graceful worker retirement by moving to another worker.

    Note: the moved actor is a fresh instance — its internal state (the
    counter value) is lost, which the post-retire assertions rely on.
    """
    # for the graceful movement of actor from one worker to another
    with cluster(nworkers=3) as (cl, w):
        client = Client(cl["address"])
        # each actor goes to a different worker by default, but worker holding ac3
        # will also hold a reference to ac
        ac = client.submit(Counter, actor=True, workers=[w[0]["address"]]).result()
        ac2 = client.submit(UsesCounter, actor=True, workers=[w[1]["address"]]).result()
        ac3 = client.submit(
            UsesCounterInit, ac, actor=True, workers=[w[2]["address"]]
        ).result()
        assert ac.increment().result() == 1
        assert ac2.do_inc(ac).result() == 2
        assert ac3.do_inc().result() == 3

        # Retire the worker hosting the Counter actor; the actor should
        # be recreated elsewhere on the next call.
        to_retire = ac._address
        client.retire_workers([to_retire])

        # counter value has reset to zero
        assert ac.increment().result() == 1
        assert ac2.do_inc(ac).result() == 2
        # on this one, the remote copy also needs to reset its address
        assert ac3.do_inc().result() == 3

        # for cleanup
        # NOTE(review): mutating ``w`` in place so the cluster context
        # manager does not try to tear down the already-retired worker.
        w[:] = [_ for _ in w if _["address"] != to_retire]
        del ac, ac2, ac3


def test_actor_kill():
    """Actors are recreated on another worker after an abrupt worker death.

    Unlike graceful retirement, the hosting worker is killed from inside
    the actor (``sys.exit``), exercising the failure-recovery path.
    """
    # for the graceful movement of actor from one worker to another
    with cluster(nworkers=3) as (cl, w):
        client = Client(cl["address"])
        # each actor goes to a different worker by default, but worker holding ac3
        # will also hold a reference to ac
        ac = client.submit(Counter, actor=True, workers=[w[0]["address"]]).result()
        ac2 = client.submit(UsesCounter, actor=True, workers=[w[1]["address"]]).result()
        ac3 = client.submit(
            UsesCounterInit, ac, actor=True, workers=[w[2]["address"]]
        ).result()
        assert ac.increment().result() == 1
        assert ac2.do_inc(ac).result() == 2
        assert ac3.do_inc().result() == 3

        # Arm the kill switch, then trigger sys.exit() inside the worker
        # process hosting ``ac``.
        assert ac.set_kill().result()
        assert ac.kill_now().result()

        # counter value has reset to zero
        assert ac.increment().result() == 1
        assert ac2.do_inc(ac).result() == 2
        # on this one, the remote copy also needs to reset its address
        # NOTE(review): bare prints look like leftover debugging — these
        # should probably be asserts before merge.
        print(ac3.do_inc().result())
        print("DONE")

        # for cleanup
        # NOTE(review): assumes the killed worker is w[0] — confirm.
        w.pop(0)
        del ac, ac2, ac3
Loading