Skip to content

Commit

Permalink
Future deserialization without available client (#7580)
Browse files Browse the repository at this point in the history
  • Loading branch information
fjetter authored Mar 21, 2023
1 parent 1b92553 commit 1b34a5b
Show file tree
Hide file tree
Showing 19 changed files with 302 additions and 154 deletions.
24 changes: 16 additions & 8 deletions distributed/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,20 +95,22 @@ def __init__(self, cls, address, key, worker=None):
super().__init__(key)
self._cls = cls
self._address = address
self._key = key
self._future = None
if worker:
self._worker = worker
self._client = None
else:
self._worker = worker
self._client = None
self._try_bind_worker_client()

def _try_bind_worker_client(self):
if not self._worker:
try:
# TODO: `get_worker` may return the wrong worker instance for async local clusters (most tests)
# when run outside of a task (when deserializing a key pointing to an Actor, etc.)
self._worker = get_worker()
except ValueError:
self._worker = None
if not self._client:
try:
self._client = get_client()
self._future = Future(key, inform=self._worker is None)
self._future = Future(self._key, inform=False)
# ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
except ValueError:
self._client = None
Expand All @@ -121,20 +123,26 @@ def __reduce__(self):

@property
def _io_loop(self):
if self._worker is None and self._client is None:
self._try_bind_worker_client()
if self._worker:
return self._worker.loop
else:
return self._client.loop

@property
def _scheduler_rpc(self):
if self._worker is None and self._client is None:
self._try_bind_worker_client()
if self._worker:
return self._worker.scheduler
else:
return self._client.scheduler

@property
def _worker_rpc(self):
if self._worker is None and self._client is None:
self._try_bind_worker_client()
if self._worker:
return self._worker.rpc(self._address)
else:
Expand Down Expand Up @@ -168,7 +176,7 @@ def __getattr__(self, key):
raise ValueError(
"Worker holding Actor was lost. Status: " + self._future.status
)

self._try_bind_worker_client()
if (
self._worker
and self._worker.address == self._address
Expand Down
97 changes: 57 additions & 40 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,32 +198,58 @@ class Future(WrappedKey):
def __init__(self, key, client=None, inform=True, state=None):
self.key = key
self._cleared = False
tkey = stringify(key)
self.client = client or Client.current()
self.client._inc_ref(tkey)
self._generation = self.client.generation
self._tkey = stringify(key)
self._client = client
self._input_state = state
self._inform = inform
self._state = None
self._bind_late()

if tkey in self.client.futures:
self._state = self.client.futures[tkey]
else:
self._state = self.client.futures[tkey] = FutureState()

if inform:
self.client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [stringify(key)],
"client": self.client.id,
}
)
@property
def client(self):
self._bind_late()
return self._client

if state is not None:
def _bind_late(self):
if not self._client:
try:
handler = self.client._state_handlers[state]
except KeyError:
pass
client = get_client()
except ValueError:
client = None
self._client = client
if self._client and not self._state:
self._client._inc_ref(self._tkey)
self._generation = self._client.generation

if self._tkey in self._client.futures:
self._state = self._client.futures[self._tkey]
else:
handler(key=key)
self._state = self._client.futures[self._tkey] = FutureState()

if self._inform:
self._client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [self._tkey],
"client": self._client.id,
}
)

if self._input_state is not None:
try:
handler = self._client._state_handlers[self._input_state]
except KeyError:
pass
else:
handler(key=self.key)

def _verify_initialized(self):
if not self.client or not self._state:
raise RuntimeError(
f"{type(self)} object not properly initialized. This can happen"
" if the object is being deserialized outside of the context of"
" a Client or Worker."
)

@property
def executor(self):
Expand Down Expand Up @@ -277,6 +303,7 @@ def result(self, timeout=None):
result
The result of the computation. Or a coroutine if the client is asynchronous.
"""
self._verify_initialized()
if self.client.asynchronous:
return self.client.sync(self._result, callback_timeout=timeout)

Expand Down Expand Up @@ -338,6 +365,7 @@ def exception(self, timeout=None, **kwargs):
--------
Future.traceback
"""
self._verify_initialized()
return self.client.sync(self._exception, callback_timeout=timeout, **kwargs)

def add_done_callback(self, fn):
Expand All @@ -354,6 +382,7 @@ def add_done_callback(self, fn):
fn : callable
The method or function to be called
"""
self._verify_initialized()
cls = Future
if cls._cb_executor is None or cls._cb_executor_pid != os.getpid():
try:
Expand Down Expand Up @@ -381,6 +410,7 @@ def cancel(self, **kwargs):
--------
Client.cancel
"""
self._verify_initialized()
return self.client.cancel([self], **kwargs)

def retry(self, **kwargs):
Expand All @@ -390,6 +420,7 @@ def retry(self, **kwargs):
--------
Client.retry
"""
self._verify_initialized()
return self.client.retry([self], **kwargs)

def cancelled(self):
Expand Down Expand Up @@ -440,6 +471,7 @@ def traceback(self, timeout=None, **kwargs):
--------
Future.exception
"""
self._verify_initialized()
return self.client.sync(self._traceback, callback_timeout=timeout, **kwargs)

@property
Expand All @@ -454,31 +486,16 @@ def release(self):
This method can be called from different threads
(see e.g. Client.get() or Future.__del__())
"""
self._verify_initialized()
if not self._cleared and self.client.generation == self._generation:
self._cleared = True
try:
self.client.loop.add_callback(self.client._dec_ref, stringify(self.key))
except TypeError: # pragma: no cover
pass # Shutting down, add_callback may be None

def __getstate__(self):
return self.key, self.client.scheduler.address

def __setstate__(self, state):
key, address = state
try:
c = Client.current(allow_global=False)
except ValueError:
c = get_client(address)
self.__init__(key, c)
c._send_to_scheduler(
{
"op": "update-graph",
"tasks": {},
"keys": [stringify(self.key)],
"client": c.id,
}
)
def __reduce__(self) -> str | tuple[Any, ...]:
return Future, (self.key,)

def __del__(self):
try:
Expand Down
8 changes: 7 additions & 1 deletion distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,13 @@ async def setup(self, worker):
from distributed.semaphore import Semaphore

async with (
await Semaphore(max_leases=1, name=socket.gethostname(), register=True)
await Semaphore(
max_leases=1,
name=socket.gethostname(),
register=True,
scheduler_rpc=worker.scheduler,
loop=worker.loop,
)
):
if not await self._is_installed(worker):
logger.info(
Expand Down
14 changes: 10 additions & 4 deletions distributed/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,18 @@ class Event:
"""

def __init__(self, name=None, client=None):
try:
self.client = client or get_client()
except ValueError:
self.client = None
self._client = client
self.name = name or "event-" + uuid.uuid4().hex

@property
def client(self):
if not self._client:
try:
self._client = get_client()
except ValueError:
pass
return self._client

def __await__(self):
"""async constructor
Expand Down
28 changes: 21 additions & 7 deletions distributed/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

from dask.utils import parse_timedelta

from distributed.client import Client
from distributed.utils import TimeoutError, log_errors, wait_for
from distributed.worker import get_worker
from distributed.worker import get_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -95,15 +94,28 @@ class Lock:
"""

def __init__(self, name=None, client=None):
try:
self.client = client or Client.current()
except ValueError:
# Initialise new client
self.client = get_worker().client
self._client = client
self.name = name or "lock-" + uuid.uuid4().hex
self.id = uuid.uuid4().hex
self._locked = False

@property
def client(self):
if not self._client:
try:
self._client = get_client()
except ValueError:
pass
return self._client

def _verify_running(self):
if not self.client:
raise RuntimeError(
f"{type(self)} object not properly initialized. This can happen"
" if the object is being deserialized outside of the context of"
" a Client or Worker."
)

def acquire(self, blocking=True, timeout=None):
"""Acquire the lock
Expand All @@ -127,6 +139,7 @@ def acquire(self, blocking=True, timeout=None):
-------
True or False whether or not it successfully acquired the lock
"""
self._verify_running()
timeout = parse_timedelta(timeout)

if not blocking:
Expand All @@ -145,6 +158,7 @@ def acquire(self, blocking=True, timeout=None):

def release(self):
"""Release the lock if already acquired"""
self._verify_running()
if not self.locked():
raise ValueError("Lock is not yet acquired")
result = self.client.sync(
Expand Down
15 changes: 10 additions & 5 deletions distributed/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,20 @@ def process(
merge
"""
if depth is None:
depth = sys.getrecursionlimit() - 50
if depth <= 0:
return None
# Cut off rather conservatively since the output of the profiling
# sometimes need to be recursed into as well, e.g. for serialization
# which can cause recursion errors later on since this can generate
# deeply nested dictionaries
depth = min(250, sys.getrecursionlimit() // 4)

if any(frame.f_code.co_filename.endswith(o) for o in omit):
return None

prev = frame.f_back
if prev is not None and (
stop is None or not prev.f_code.co_filename.endswith(stop)
if (
depth > 0
and prev is not None
and (stop is None or not prev.f_code.co_filename.endswith(stop))
):
new_state = process(prev, frame, state, stop=stop, depth=depth - 1)
if new_state is None:
Expand Down
21 changes: 21 additions & 0 deletions distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from threading import Thread
from time import sleep

Expand All @@ -8,6 +9,7 @@
import dask
from dask.sizeof import sizeof

from distributed.compatibility import WINDOWS
from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
from distributed.protocol.compression import (
compressions,
Expand Down Expand Up @@ -412,3 +414,22 @@ def test_sizeof_serialize(Wrapper, Wrapped):
assert size <= sizeof(ser_obj) < size * 1.05
serialized = Wrapped(*serialize(ser_obj))
assert size <= sizeof(serialized) < size * 1.05


@pytest.mark.skipif(WINDOWS, reason="On windows this is triggering a stackoverflow")
def test_deeply_nested_structures():
# These kind of deeply nested structures are generated in our profiling code
def gen_deeply_nested(depth):
msg = {}
d = msg
while depth:
depth -= 1
d["children"] = d = {}
return msg

msg = gen_deeply_nested(sys.getrecursionlimit() - 100)
with pytest.raises(TypeError, match="Could not serialize object"):
serialize(msg, on_error="raise")

msg = gen_deeply_nested(sys.getrecursionlimit() // 4)
assert isinstance(serialize(msg), tuple)
Loading

0 comments on commit 1b34a5b

Please sign in to comment.