Commit 65e47f7: Merge branch 'branch-0.20' into update-docs

charlesbluca authored Apr 12, 2021
2 parents: 4baa15d + 1526017
Showing 5 changed files with 100 additions and 29 deletions.
dask_cuda/cli/dask_cuda_worker.py (9 changes: 0 additions & 9 deletions)
@@ -145,13 +145,6 @@
     available). Can be a string (like ``"eth0"`` for NVLink or ``"ib0"`` for
     InfiniBand) or ``None`` to fall back on the default interface.""",
 )
-@click.option(
-    "--death-timeout",
-    type=str,
-    default=None,
-    help="""Amount of time to wait for a scheduler before closing. Can be a string (like
-    ``"3s"``, ``"3.5 seconds"``, or ``"300ms"``) or ``None`` to disable timeout.""",
-)
 @click.option(
     "--preload",
     type=str,
@@ -260,7 +253,6 @@ def main(
     local_directory,
     scheduler_file,
     interface,
-    death_timeout,
     preload,
     dashboard_prefix,
     tls_ca_file,
@@ -305,7 +297,6 @@ def main(
     local_directory,
     scheduler_file,
     interface,
-    death_timeout,
     preload,
     dashboard_prefix,
     security,
dask_cuda/cuda_worker.py (1 change: 0 additions & 1 deletion)
@@ -64,7 +64,6 @@ def __init__(
         local_directory=None,
         scheduler_file=None,
         interface=None,
-        death_timeout=None,
         preload=[],
         dashboard_prefix=None,
         security=None,
dask_cuda/explicit_comms/comms.py (59 changes: 46 additions & 13 deletions)
@@ -1,5 +1,6 @@
 import asyncio
 import concurrent.futures
+import contextlib
 import time
 import uuid
 from typing import List, Optional
@@ -11,6 +12,33 @@
 _default_comms = None


+def get_multi_lock_or_null_context(multi_lock_context, *args, **kwargs):
+    """Return either a MultiLock or a NULL context
+    Parameters
+    ----------
+    multi_lock_context: bool
+        If True return MultiLock context else return a NULL context that
+        doesn't do anything
+    *args, **kwargs:
+        Arguments passed to the MultiLock creation
+    Returns
+    -------
+    context: context
+        Either `MultiLock(*args, **kwargs)` or a NULL context
+    """
+    if multi_lock_context:
+        from distributed import MultiLock
+
+        return MultiLock(*args, **kwargs)
+    else:
+        # Use a null context that doesn't do anything
+        # TODO: use `contextlib.nullcontext()` from Python 3.7+
+        return contextlib.suppress()
+
+
 def default_comms(client: Optional[Client] = None) -> "CommsContext":
     """Return the default comms object
@@ -194,7 +222,7 @@ def submit(self, worker, coroutine, *args, wait=False):
         )
         return ret.result() if wait else ret

-    def run(self, coroutine, *args, workers=None):
+    def run(self, coroutine, *args, workers=None, lock_workers=False):
         """Run a coroutine on multiple workers

         The coroutine is given the worker's state dict as the first argument
@@ -208,6 +236,9 @@ def run(self, coroutine, *args, workers=None):
             Arguments for `coroutine`
         workers: list, optional
             List of workers. Default is all workers
+        lock_workers: bool, optional
+            Use distributed.MultiLock to get exclusive access to the workers. Use
+            this flag to support parallel runs.

         Returns
         -------
@@ -216,16 +247,18 @@
"""
if workers is None:
workers = self.worker_addresses
ret = []
for worker in workers:
ret.append(
self.client.submit(
_run_coroutine_on_worker,
self.sessionId,
coroutine,
args,
workers=[worker],
pure=False,

with get_multi_lock_or_null_context(lock_workers, workers):
ret = []
for worker in workers:
ret.append(
self.client.submit(
_run_coroutine_on_worker,
self.sessionId,
coroutine,
args,
workers=[worker],
pure=False,
)
)
)
return self.client.gather(ret)
return self.client.gather(ret)
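
For readers skimming the diff, a minimal, hypothetical usage sketch of the new lock_workers flag follows (the cluster size, coroutine body, and the report_rank name are illustrative only, not part of this commit). With lock_workers=True, run() wraps its submissions in a distributed.MultiLock keyed on the worker addresses; with the default lock_workers=False, the helper above returns contextlib.suppress() with no arguments, which acts as a do-nothing context manager (the TODO notes that contextlib.nullcontext() could replace it on Python 3.7+).

# Hypothetical sketch: exercising CommsContext.run(..., lock_workers=True).
# Assumes dask_cuda is installed and distributed is recent enough to ship MultiLock.
from distributed import Client
from distributed.deploy.local import LocalCluster

from dask_cuda.explicit_comms import comms


async def report_rank(state):
    # Each worker coroutine receives its per-session state dict as the first argument.
    return state["rank"]


if __name__ == "__main__":
    with LocalCluster(protocol="tcp", n_workers=4, dashboard_address=None) as cluster:
        with Client(cluster) as client:
            c = comms.CommsContext(client)
            # Concurrent run() calls that touch the same workers are serialized
            # while the MultiLock is held; disjoint worker sets proceed in parallel.
            ranks = c.run(report_rank, lock_workers=True)
            assert sorted(ranks) == list(range(4))
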
dask_cuda/tests/test_explicit_comms.py (58 changes: 54 additions & 4 deletions)
@@ -1,3 +1,4 @@
+import asyncio
 import multiprocessing as mp

 import numpy as np
@@ -7,7 +8,7 @@
 import dask
 from dask import dataframe as dd
 from dask.dataframe.shuffle import partitioning_index
-from distributed import Client
+from distributed import Client, get_worker
 from distributed.deploy.local import LocalCluster

 import dask_cuda
@@ -22,8 +23,8 @@
 # that UCX options of the different tests doesn't conflict.


-async def my_rank(state):
-    return state["rank"]
+async def my_rank(state, arg):
+    return state["rank"] + arg


 def _test_local_cluster(protocol):
@@ -36,7 +37,7 @@ def _test_local_cluster(protocol):
     ) as cluster:
         with Client(cluster) as client:
             c = comms.CommsContext(client)
-            assert sum(c.run(my_rank)) == sum(range(4))
+            assert sum(c.run(my_rank, 0)) == sum(range(4))


 @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
@@ -290,3 +291,52 @@ def test_jit_unspill(protocol):
     p.start()
     p.join()
     assert not p.exitcode
+
+
+def _test_lock_workers(scheduler_address, ranks):
+    async def f(_):
+        worker = get_worker()
+        if hasattr(worker, "running"):
+            assert not worker.running
+        worker.running = True
+        await asyncio.sleep(0.5)
+        assert worker.running
+        worker.running = False
+
+    with Client(scheduler_address) as client:
+        c = comms.CommsContext(client)
+        c.run(f, workers=[c.worker_addresses[r] for r in ranks], lock_workers=True)
+
+
+def test_lock_workers():
+    """
+    Testing `run(...,lock_workers=True)` by spawning 30 runs with overlapping
+    and non-overlapping worker sets.
+    """
+    try:
+        from distributed import MultiLock  # noqa F401
+    except ImportError as e:
+        pytest.skip(str(e))
+
+    with LocalCluster(
+        protocol="tcp",
+        dashboard_address=None,
+        n_workers=4,
+        threads_per_worker=5,
+        processes=True,
+    ) as cluster:
+        ps = []
+        for _ in range(5):
+            for ranks in [[0, 1], [1, 3], [2, 3]]:
+                ps.append(
+                    mp.Process(
+                        target=_test_lock_workers,
+                        args=(cluster.scheduler_address, ranks),
+                    )
+                )
+                ps[-1].start()
+
+        for p in ps:
+            p.join()
+
+        assert all(p.exitcode == 0 for p in ps)
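
The rank sets [0, 1], [1, 3], and [2, 3] give the test both contending and disjoint worker pairs, so it exercises serialization and parallel progress in the same run. As a rough, hypothetical illustration of the distributed.MultiLock behavior being relied on (the lock names below are placeholders, not real worker addresses, and nothing here is taken from the commit):

# Hypothetical sketch of MultiLock semantics, assuming a distributed release
# that provides MultiLock; names stand in for worker addresses.
from distributed import Client, MultiLock
from distributed.deploy.local import LocalCluster

if __name__ == "__main__":
    with LocalCluster(protocol="tcp", n_workers=2, dashboard_address=None) as cluster:
        with Client(cluster):
            with MultiLock(names=["worker-a", "worker-b"]):
                # While these named locks are held, another MultiLock over
                # ["worker-b", "worker-c"] would block (overlap on "worker-b"),
                # whereas one over ["worker-c", "worker-d"] could be acquired.
                pass
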
dask_cuda/tests/test_spill.py (2 changes: 0 additions & 2 deletions)
@@ -198,7 +198,6 @@ async def test_cupy_cluster_device_spill(params):
             silence_logs=False,
             dashboard_address=None,
             asynchronous=True,
-            death_timeout=60,
             device_memory_limit=params["device_memory_limit"],
             memory_limit=params["memory_limit"],
             memory_target_fraction=params["host_target"],
@@ -365,7 +364,6 @@ async def test_cudf_cluster_device_spill(params):
             memory_target_fraction=params["host_target"],
             memory_spill_fraction=params["host_spill"],
             memory_pause_fraction=params["host_pause"],
-            death_timeout=60,
             asynchronous=True,
         ) as cluster:
             async with Client(cluster, asynchronous=True) as client:
