Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Future deserialization without available client #7580

Merged
merged 11 commits
Mar 21, 2023
24 changes: 16 additions & 8 deletions distributed/actor.py
Original file line number Diff line number Diff line change
Expand Up @@ -95,20 +95,22 @@ def __init__(self, cls, address, key, worker=None):
super().__init__(key)
self._cls = cls
self._address = address
self._key = key
self._future = None
if worker:
self._worker = worker
self._client = None
else:
self._worker = worker
self._client = None
self._try_bind_worker_client()

def _try_bind_worker_client(self):
if not self._worker:
try:
# TODO: `get_worker` may return the wrong worker instance for async local clusters (most tests)
# when run outside of a task (when deserializing a key pointing to an Actor, etc.)
self._worker = get_worker()
except ValueError:
self._worker = None
if not self._client:
try:
self._client = get_client()
self._future = Future(key, inform=self._worker is None)
self._future = Future(self._key, inform=False)
# ^ When running on a worker, only hold a weak reference to the key, otherwise the key could become unreleasable.
except ValueError:
self._client = None
Expand All @@ -121,20 +123,26 @@ def __reduce__(self):

@property
def _io_loop(self):
    """Event loop of the bound worker, or of the bound client otherwise.

    Binding is attempted lazily so that an actor deserialized without an
    owner can still resolve one on first use.
    """
    if self._worker is None and self._client is None:
        self._try_bind_worker_client()
    owner = self._worker or self._client
    return owner.loop

@property
def _scheduler_rpc(self):
    """RPC handle to the scheduler via the bound worker or client.

    Tries a lazy bind first, mirroring ``_io_loop``'s resolution order:
    worker takes precedence over client.
    """
    if self._worker is None and self._client is None:
        self._try_bind_worker_client()
    owner = self._worker or self._client
    return owner.scheduler

@property
def _worker_rpc(self):
if self._worker is None and self._client is None:
self._try_bind_worker_client()
if self._worker:
return self._worker.rpc(self._address)
else:
Expand Down Expand Up @@ -168,7 +176,7 @@ def __getattr__(self, key):
raise ValueError(
"Worker holding Actor was lost. Status: " + self._future.status
)

self._try_bind_worker_client()
if (
self._worker
and self._worker.address == self._address
Expand Down
97 changes: 57 additions & 40 deletions distributed/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -198,32 +198,58 @@ class Future(WrappedKey):
def __init__(self, key, client=None, inform=True, state=None):
self.key = key
self._cleared = False
tkey = stringify(key)
self.client = client or Client.current()
self.client._inc_ref(tkey)
self._generation = self.client.generation
self._tkey = stringify(key)
self._client = client
self._input_state = state
self._inform = inform
self._state = None
self._bind_late()

if tkey in self.client.futures:
self._state = self.client.futures[tkey]
else:
self._state = self.client.futures[tkey] = FutureState()

if inform:
self.client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [stringify(key)],
"client": self.client.id,
}
)
@property
def client(self):
    # Resolve the client lazily: a future deserialized outside a
    # Client/Worker context has no client yet, so every access retries
    # the late bind before returning whatever is (or isn't) bound.
    self._bind_late()
    return self._client

if state is not None:
def _bind_late(self):
if not self._client:
try:
handler = self.client._state_handlers[state]
except KeyError:
pass
client = get_client()
except ValueError:
client = None
self._client = client
if self._client and not self._state:
self._client._inc_ref(self._tkey)
self._generation = self._client.generation

if self._tkey in self._client.futures:
self._state = self._client.futures[self._tkey]
else:
handler(key=key)
self._state = self._client.futures[self._tkey] = FutureState()

if self._inform:
self._client._send_to_scheduler(
{
"op": "client-desires-keys",
"keys": [self._tkey],
"client": self._client.id,
}
)

if self._input_state is not None:
try:
handler = self._client._state_handlers[self._input_state]
except KeyError:
pass
else:
handler(key=self.key)

def _verify_initialized(self):
    """Ensure this future has both a client and a registered state.

    Raises
    ------
    RuntimeError
        If no client could be bound or no state exists — typically because
        the future was deserialized outside a Client/Worker context.
    """
    if self.client and self._state:
        return
    raise RuntimeError(
        f"{type(self)} object not properly initialized. This can happen"
        " if the object is being deserialized outside of the context of"
        " a Client or Worker."
    )

@property
def executor(self):
Expand Down Expand Up @@ -277,6 +303,7 @@ def result(self, timeout=None):
result
The result of the computation. Or a coroutine if the client is asynchronous.
"""
self._verify_initialized()
if self.client.asynchronous:
return self.client.sync(self._result, callback_timeout=timeout)

Expand Down Expand Up @@ -338,6 +365,7 @@ def exception(self, timeout=None, **kwargs):
--------
Future.traceback
"""
self._verify_initialized()
return self.client.sync(self._exception, callback_timeout=timeout, **kwargs)

def add_done_callback(self, fn):
Expand All @@ -354,6 +382,7 @@ def add_done_callback(self, fn):
fn : callable
The method or function to be called
"""
self._verify_initialized()
cls = Future
if cls._cb_executor is None or cls._cb_executor_pid != os.getpid():
try:
Expand Down Expand Up @@ -381,6 +410,7 @@ def cancel(self, **kwargs):
--------
Client.cancel
"""
self._verify_initialized()
return self.client.cancel([self], **kwargs)

def retry(self, **kwargs):
Expand All @@ -390,6 +420,7 @@ def retry(self, **kwargs):
--------
Client.retry
"""
self._verify_initialized()
return self.client.retry([self], **kwargs)

def cancelled(self):
Expand Down Expand Up @@ -440,6 +471,7 @@ def traceback(self, timeout=None, **kwargs):
--------
Future.exception
"""
self._verify_initialized()
return self.client.sync(self._traceback, callback_timeout=timeout, **kwargs)

@property
Expand All @@ -454,31 +486,16 @@ def release(self):
This method can be called from different threads
(see e.g. Client.get() or Future.__del__())
"""
self._verify_initialized()
if not self._cleared and self.client.generation == self._generation:
self._cleared = True
try:
self.client.loop.add_callback(self.client._dec_ref, stringify(self.key))
except TypeError: # pragma: no cover
pass # Shutting down, add_callback may be None

def __getstate__(self):
return self.key, self.client.scheduler.address

def __setstate__(self, state):
key, address = state
try:
c = Client.current(allow_global=False)
except ValueError:
c = get_client(address)
self.__init__(key, c)
c._send_to_scheduler(
{
"op": "update-graph",
"tasks": {},
"keys": [stringify(self.key)],
"client": c.id,
}
)
def __reduce__(self) -> str | tuple[Any, ...]:
    # Pickle only the key; the client is intentionally NOT serialized and
    # is re-bound lazily on deserialization (see _bind_late), so futures
    # can travel between processes without carrying a live connection.
    return Future, (self.key,)

def __del__(self):
try:
Expand Down
8 changes: 7 additions & 1 deletion distributed/diagnostics/plugin.py
Original file line number Diff line number Diff line change
Expand Up @@ -299,7 +299,13 @@ async def setup(self, worker):
from distributed.semaphore import Semaphore

async with (
await Semaphore(max_leases=1, name=socket.gethostname(), register=True)
await Semaphore(
max_leases=1,
name=socket.gethostname(),
register=True,
scheduler_rpc=worker.scheduler,
loop=worker.loop,
)
):
if not await self._is_installed(worker):
logger.info(
Expand Down
14 changes: 10 additions & 4 deletions distributed/event.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,12 +178,18 @@ class Event:
"""

def __init__(self, name=None, client=None):
try:
self.client = client or get_client()
except ValueError:
self.client = None
self._client = client
self.name = name or "event-" + uuid.uuid4().hex

@property
def client(self):
    """Client bound to this event, resolved lazily.

    Returns the client passed at construction, else tries to discover one
    from the current context; stays ``None`` when none is available.
    """
    if self._client:
        return self._client
    try:
        self._client = get_client()
    except ValueError:
        # No client/worker in this context; remain unbound for now.
        pass
    return self._client

def __await__(self):
"""async constructor

Expand Down
28 changes: 21 additions & 7 deletions distributed/lock.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,9 +7,8 @@

from dask.utils import parse_timedelta

from distributed.client import Client
from distributed.utils import TimeoutError, log_errors, wait_for
from distributed.worker import get_worker
from distributed.worker import get_client

logger = logging.getLogger(__name__)

Expand Down Expand Up @@ -95,15 +94,28 @@ class Lock:
"""

def __init__(self, name=None, client=None):
try:
self.client = client or Client.current()
except ValueError:
# Initialise new client
self.client = get_worker().client
self._client = client
self.name = name or "lock-" + uuid.uuid4().hex
self.id = uuid.uuid4().hex
self._locked = False

@property
def client(self):
    """Lazily resolved client backing this lock.

    Falls back to ``get_client()`` when no client was supplied; yields
    ``None`` if no client can be found in the current context.
    """
    if self._client:
        return self._client
    try:
        self._client = get_client()
    except ValueError:
        # Outside any Client/Worker context — leave unbound.
        pass
    return self._client

def _verify_running(self):
    """Ensure a client is bound before performing any lock operation.

    Raises
    ------
    RuntimeError
        If the lock was deserialized outside a Client/Worker context and
        no client could be resolved.
    """
    if self.client:
        return
    raise RuntimeError(
        f"{type(self)} object not properly initialized. This can happen"
        " if the object is being deserialized outside of the context of"
        " a Client or Worker."
    )

def acquire(self, blocking=True, timeout=None):
"""Acquire the lock

Expand All @@ -127,6 +139,7 @@ def acquire(self, blocking=True, timeout=None):
-------
True or False whether or not it successfully acquired the lock
"""
self._verify_running()
timeout = parse_timedelta(timeout)

if not blocking:
Expand All @@ -145,6 +158,7 @@ def acquire(self, blocking=True, timeout=None):

def release(self):
"""Release the lock if already acquired"""
self._verify_running()
if not self.locked():
raise ValueError("Lock is not yet acquired")
result = self.client.sync(
Expand Down
15 changes: 10 additions & 5 deletions distributed/profile.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,15 +155,20 @@ def process(
merge
"""
if depth is None:
depth = sys.getrecursionlimit() - 50
if depth <= 0:
return None
Comment on lines -159 to -160
Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this check caused us to not collect any information in these cases. I think it's still valuable to get a snapshot through even if it's not the lowest frame. Moving the depth check further below achieves this (see test_profile.py)

# Cut off rather conservatively since the output of the profiling
# sometimes need to be recursed into as well, e.g. for serialization
# which can cause recursion errors later on since this can generate
# deeply nested dictionaries
depth = min(250, sys.getrecursionlimit() // 4)

if any(frame.f_code.co_filename.endswith(o) for o in omit):
return None

prev = frame.f_back
if prev is not None and (
stop is None or not prev.f_code.co_filename.endswith(stop)
if (
depth > 0
and prev is not None
and (stop is None or not prev.f_code.co_filename.endswith(stop))
):
new_state = process(prev, frame, state, stop=stop, depth=depth - 1)
if new_state is None:
Expand Down
21 changes: 21 additions & 0 deletions distributed/protocol/tests/test_protocol.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
from __future__ import annotations

import sys
from threading import Thread
from time import sleep

Expand All @@ -8,6 +9,7 @@
import dask
from dask.sizeof import sizeof

from distributed.compatibility import WINDOWS
from distributed.protocol import dumps, loads, maybe_compress, msgpack, to_serialize
from distributed.protocol.compression import (
compressions,
Expand Down Expand Up @@ -412,3 +414,22 @@ def test_sizeof_serialize(Wrapper, Wrapped):
assert size <= sizeof(ser_obj) < size * 1.05
serialized = Wrapped(*serialize(ser_obj))
assert size <= sizeof(serialized) < size * 1.05


@pytest.mark.skipif(WINDOWS, reason="On windows this is triggering a stackoverflow")
def test_deeply_nested_structures():
    # Structures like these are generated by our profiling code; serializing
    # them must fail loudly rather than overflow the interpreter stack.
    def build_nested(levels):
        """Return a dict nested ``levels`` deep under repeated 'children' keys."""
        root = {}
        node = root
        for _ in range(levels):
            node["children"] = node = {}
        return root

    # Near the recursion limit, serialization must raise instead of crashing.
    too_deep = build_nested(sys.getrecursionlimit() - 100)
    with pytest.raises(TypeError, match="Could not serialize object"):
        serialize(too_deep, on_error="raise")

    # A conservatively bounded depth still serializes fine.
    assert isinstance(serialize(build_nested(sys.getrecursionlimit() // 4)), tuple)
fjetter marked this conversation as resolved.
Show resolved Hide resolved
Loading