Skip to content

Commit

Permalink
support get_worker() and get_client() in client.run calls
Browse files Browse the repository at this point in the history
Fixes dask#7763
  • Loading branch information
graingert committed Jun 14, 2023
1 parent 74a1bcd commit 9495fba
Show file tree
Hide file tree
Showing 5 changed files with 238 additions and 98 deletions.
54 changes: 48 additions & 6 deletions distributed/nanny.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,14 +16,14 @@
from inspect import isawaitable
from queue import Empty
from time import sleep as sync_sleep
from typing import TYPE_CHECKING, ClassVar, Literal
from typing import TYPE_CHECKING, Any, ClassVar, Literal

from toolz import merge
from tornado.ioloop import IOLoop

import dask
from dask.system import CPU_COUNT
from dask.utils import parse_timedelta
from dask.utils import funcname, parse_timedelta

from distributed import preloading
from distributed.comm import get_address_host
Expand All @@ -42,19 +42,23 @@
from distributed.node import ServerNode
from distributed.process import AsyncProcess
from distributed.proctitle import enable_proctitle_on_children
from distributed.protocol import pickle
from distributed.protocol import pickle, to_serialize
from distributed.protocol.serialize import _is_dumpable
from distributed.security import Security
from distributed.utils import (
convert_args_to_str,
convert_kwargs_to_str,
get_ip,
get_mp_context,
has_arg,
iscoroutinefunction,
json_load_robust,
log_errors,
parse_ports,
silence_logging_cmgr,
wait_for,
)
from distributed.worker import Worker, run
from distributed.worker import Worker
from distributed.worker_memory import (
DeprecatedMemoryManagerAttribute,
DeprecatedMemoryMonitor,
Expand Down Expand Up @@ -500,8 +504,46 @@ async def _():
def is_alive(self):
return self.process is not None and self.process.is_alive()

def run(self, comm, *args, **kwargs):
return run(self, comm, *args, **kwargs)
async def run(
self,
function: bytes,
args: bytes,
kwargs: bytes,
wait: bool = True,
) -> Any:
function_loaded = pickle.loads(function)
is_coro = iscoroutinefunction(function_loaded)
assert wait or is_coro, "Combination not supported"
args_loaded = pickle.loads(args)
kwargs_loaded = pickle.loads(kwargs)
if has_arg(function, "dask_worker"):
kwargs_loaded["dask_worker"] = self

logger.info("Run out-of-band function %r", funcname(function_loaded))

try:
if not is_coro:
result = function_loaded(*args_loaded, **kwargs_loaded)
else:
if wait:
result = await function_loaded(*args_loaded, **kwargs_loaded)
else:
self._ongoing_background_tasks.call_soon(
function_loaded, *args_loaded, **kwargs_loaded
)
result = None

except Exception as e:
logger.warning(
"Run Failed\nFunction: %s\nargs: %s\nkwargs: %s\n",
str(funcname(function))[:1000],
convert_args_to_str(args_loaded, max_len=1000),
convert_kwargs_to_str(kwargs_loaded, max_len=1000),
exc_info=True,
)

return error_message(e)
return {"status": "OK", "result": to_serialize(result)}

def _on_worker_exit_sync(self, exitcode):
try:
Expand Down
53 changes: 44 additions & 9 deletions distributed/scheduler.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,6 +55,7 @@
from dask.utils import (
format_bytes,
format_time,
funcname,
key_split,
parse_bytes,
parse_timedelta,
Expand Down Expand Up @@ -90,6 +91,7 @@
from distributed.multi_lock import MultiLockExtension
from distributed.node import ServerNode
from distributed.proctitle import setproctitle
from distributed.protocol import to_serialize
from distributed.protocol.pickle import dumps, loads
from distributed.protocol.serialize import Serialized, ToPickle, serialize
from distributed.publish import PublishExtension
Expand All @@ -104,9 +106,13 @@
from distributed.utils import (
All,
TimeoutError,
convert_args_to_str,
convert_kwargs_to_str,
empty_context,
format_dashboard_link,
get_fileno_limit,
has_arg,
iscoroutinefunction,
key_split_group,
log_errors,
no_default,
Expand Down Expand Up @@ -7335,12 +7341,11 @@ def get_nbytes(

return result

def run_function(
async def run_function(
self,
comm: Comm,
function: Callable,
args: tuple = (),
kwargs: dict | None = None,
function: bytes,
args: bytes,
kwargs: bytes,
wait: bool = True,
) -> Any:
"""Run a function within this process
Expand All @@ -7349,17 +7354,47 @@ def run_function(
--------
Client.run_on_scheduler
"""
from distributed.worker import run

if not dask.config.get("distributed.scheduler.pickle"):
raise ValueError(
"Cannot run function as the scheduler has been explicitly disallowed from "
"deserializing arbitrary bytestrings using pickle via the "
"'distributed.scheduler.pickle' configuration setting."
)
kwargs = kwargs or {}
self.log_event("all", {"action": "run-function", "function": function})
return run(self, comm, function=function, args=args, kwargs=kwargs, wait=wait)

function_loaded = pickle.loads(function)
is_coro = iscoroutinefunction(function_loaded)
assert wait or is_coro, "Combination not supported"
args_loaded = pickle.loads(args)
kwargs_loaded = pickle.loads(kwargs)
if has_arg(function, "dask_scheduler"):
kwargs_loaded["dask_scheduler"] = self

logger.info("Run out-of-band function %r", funcname(function_loaded))

try:
if not is_coro:
result = function_loaded(*args_loaded, **kwargs_loaded)
else:
if wait:
result = await function_loaded(*args_loaded, **kwargs_loaded)
else:
self._ongoing_background_tasks.call_soon(
function_loaded, *args_loaded, **kwargs_loaded
)
result = None

except Exception as e:
logger.warning(
"Run Failed\nFunction: %s\nargs: %s\nkwargs: %s\n",
str(funcname(function))[:1000],
convert_args_to_str(args_loaded, max_len=1000),
convert_kwargs_to_str(kwargs_loaded, max_len=1000),
exc_info=True,
)

return error_message(e)
return {"status": "OK", "result": to_serialize(result)}

def set_metadata(self, keys: list[str], value: object = None) -> None:
metadata = self.task_metadata
Expand Down
38 changes: 38 additions & 0 deletions distributed/tests/test_client.py
Original file line number Diff line number Diff line change
Expand Up @@ -6624,6 +6624,44 @@ async def f(dask_worker):
assert b.foo == "bar"


@gen_cluster(client=True)
async def test_run_get_worker(c, s, a, b):
def f():
get_worker().foo = "bar"

await c.run(f)

assert a.foo == "bar"
assert b.foo == "bar"


@gen_cluster(client=True)
async def test_run_get_worker_async_def(c, s, a, b):
async def f():
await asyncio.sleep(0.01)
get_worker().foo = "bar"

await c.run(f)

assert a.foo == "bar"
assert b.foo == "bar"


@gen_cluster(client=True)
async def test_run_get_worker_async_def_wait(c, s, a, b):
async def f():
await asyncio.sleep(0.01)
get_worker().foo = "bar"

await c.run(f, wait=False)

while not hasattr(a, "foo") or not hasattr(b, "foo"):
await asyncio.sleep(0.01)

assert a.foo == "bar"
assert b.foo == "bar"


@pytest.mark.slow
@pytest.mark.skipif(WINDOWS, reason="frequently kills off the whole test suite")
@pytest.mark.parametrize("local", [True, False])
Expand Down
39 changes: 39 additions & 0 deletions distributed/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1876,3 +1876,42 @@ async def wait_for(fut: Awaitable[T], timeout: float) -> T:

async def wait_for(fut: Awaitable[T], timeout: float) -> T:
return await asyncio.wait_for(fut, timeout)


def convert_args_to_str(args: tuple[object, ...], max_len: int | None = None) -> str:
"""Convert args to a string, allowing for some arguments to raise
exceptions during conversion and ignoring them.
"""
length = 0
strs = ["" for i in range(len(args))]
for i, arg in enumerate(args):
try:
sarg = repr(arg)
except Exception:
sarg = "< could not convert arg to str >"
strs[i] = sarg
length += len(sarg) + 2
if max_len is not None and length > max_len:
return "({}".format(", ".join(strs[: i + 1]))[:max_len]
else:
return "({})".format(", ".join(strs))


def convert_kwargs_to_str(kwargs: dict, max_len: int | None = None) -> str:
"""Convert kwargs to a string, allowing for some arguments to raise
exceptions during conversion and ignoring them.
"""
length = 0
strs = ["" for i in range(len(kwargs))]
for i, (argname, arg) in enumerate(kwargs.items()):
try:
sarg = repr(arg)
except Exception:
sarg = "< could not convert arg to str >"
skwarg = repr(argname) + ": " + sarg
strs[i] = skwarg
length += len(skwarg) + 2
if max_len is not None and length > max_len:
return "{{{}".format(", ".join(strs[: i + 1]))[:max_len]
else:
return "{{{}}}".format(", ".join(strs))
Loading

0 comments on commit 9495fba

Please sign in to comment.