Skip to content

Commit

Permalink
Attempt to get client from worker in Queue and Variable (#4490)
Browse files Browse the repository at this point in the history
  • Loading branch information
jrbourbeau authored Feb 11, 2021
1 parent 2b9ba97 commit 725f001
Show file tree
Hide file tree
Showing 4 changed files with 57 additions and 12 deletions.
12 changes: 8 additions & 4 deletions distributed/queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,7 +7,7 @@

from .client import Future, Client
from .utils import sync, thread_state
from .worker import get_client
from .worker import get_client, get_worker
from .utils import parse_timedelta

logger = logging.getLogger(__name__)
Expand Down Expand Up @@ -150,8 +150,8 @@ class Queue:
Name used by other clients and the scheduler to identify the queue. If
not given, a random name will be generated.
client: Client (optional)
Client used for communication with the scheduler. Defaults to the
value of ``Client.current()``.
Client used for communication with the scheduler.
If not given, the default global client will be used.
maxsize: int (optional)
Number of items allowed in the queue. If 0 (the default), the queue
size is unbounded.
Expand All @@ -170,7 +170,11 @@ class Queue:
"""

def __init__(self, name=None, client=None, maxsize=0):
self.client = client or Client.current()
try:
self.client = client or Client.current()
except ValueError:
# Initialise new client
self.client = get_worker().client
self.name = name or "queue-" + uuid.uuid4().hex
self._event_started = asyncio.Event()
if self.client.asynchronous or getattr(
Expand Down
21 changes: 20 additions & 1 deletion distributed/tests/test_queues.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

from distributed import Client, Queue, Nanny, worker_client, wait, TimeoutError
from distributed.metrics import time
from distributed.utils_test import gen_cluster, inc, div
from distributed.utils_test import gen_cluster, inc, div, popen
from distributed.utils_test import client, cluster_fixture, loop # noqa: F401


Expand Down Expand Up @@ -276,3 +276,22 @@ def get():
res = c.submit(get)

await c.gather([res, fut])


def test_queue_in_task(loop):
# Ensure that we can create a Queue inside a task on a
# worker in a separate Python process than the client
with popen(["dask-scheduler", "--no-dashboard"]):
with popen(["dask-worker", "127.0.0.1:8786"]):
with Client("tcp://127.0.0.1:8786", loop=loop) as c:
c.wait_for_workers(1)

x = Queue("x")
x.put(123)

def foo():
y = Queue("x")
return y.get()

result = c.submit(foo).result()
assert result == 123
22 changes: 20 additions & 2 deletions distributed/tests/test_variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,9 +10,8 @@
from distributed import Client, Variable, worker_client, Nanny, wait, TimeoutError
from distributed.metrics import time
from distributed.compatibility import WINDOWS
from distributed.utils_test import gen_cluster, inc, div
from distributed.utils_test import gen_cluster, inc, div, captured_logger, popen
from distributed.utils_test import client, cluster_fixture, loop # noqa: F401
from distributed.utils_test import captured_logger


@gen_cluster(client=True)
Expand Down Expand Up @@ -40,6 +39,25 @@ async def test_variable(c, s, a, b):
assert time() < start + 5


def test_variable_in_task(loop):
# Ensure that we can create a Variable inside a task on a
# worker in a separate Python process than the client
with popen(["dask-scheduler", "--no-dashboard"]):
with popen(["dask-worker", "127.0.0.1:8786"]):
with Client("tcp://127.0.0.1:8786", loop=loop) as c:
c.wait_for_workers(1)

x = Variable("x")
x.set(123)

def foo():
y = Variable("x")
return y.get()

result = c.submit(foo).result()
assert result == 123


@gen_cluster(client=True)
async def test_delete_unset_variable(c, s, a, b):
x = Variable()
Expand Down
14 changes: 9 additions & 5 deletions distributed/variable.py
Original file line number Diff line number Diff line change
Expand Up @@ -9,13 +9,13 @@
from dask.utils import stringify
from .client import Future, Client
from .utils import log_errors, TimeoutError, parse_timedelta
from .worker import get_client
from .worker import get_client, get_worker

logger = logging.getLogger(__name__)


class VariableExtension:
"""An extension for the scheduler to manage queues
"""An extension for the scheduler to manage Variables
This adds the following routes to the scheduler
Expand Down Expand Up @@ -145,8 +145,8 @@ class Variable:
Name used by other clients and the scheduler to identify the variable.
If not given, a random name will be generated.
client: Client (optional)
Client used for communication with the scheduler. Defaults to the
value of ``Client.current()``.
Client used for communication with the scheduler.
If not given, the default global client will be used.
Examples
--------
Expand All @@ -165,7 +165,11 @@ class Variable:
"""

def __init__(self, name=None, client=None, maxsize=0):
self.client = client or Client.current()
try:
self.client = client or Client.current()
except ValueError:
# Initialise new client
self.client = get_worker().client
self.name = name or "variable-" + uuid.uuid4().hex

async def _set(self, value):
Expand Down

0 comments on commit 725f001

Please sign in to comment.