Move some dask_cuda.utils pieces to their own modules #1263

Merged (8 commits) on Oct 26, 2023
4 changes: 1 addition & 3 deletions dask_cuda/cuda_worker.py
@@ -20,11 +20,9 @@
 
 from .device_host_file import DeviceHostFile
 from .initialize import initialize
+from .plugins import CPUAffinity, PreImport, RMMSetup
 from .proxify_host_file import ProxifyHostFile
 from .utils import (
-    CPUAffinity,
-    PreImport,
-    RMMSetup,
     cuda_visible_devices,
     get_cpu_affinity,
     get_n_gpus,
29 changes: 13 additions & 16 deletions dask_cuda/local_cuda_cluster.py
@@ -3,19 +3,16 @@
 import os
 import warnings
 from functools import partial
-from typing import Literal
 
 import dask
 from distributed import LocalCluster, Nanny, Worker
 from distributed.worker_memory import parse_memory_limit
 
 from .device_host_file import DeviceHostFile
 from .initialize import initialize
+from .plugins import CPUAffinity, PreImport, RMMSetup
 from .proxify_host_file import ProxifyHostFile
 from .utils import (
-    CPUAffinity,
-    PreImport,
-    RMMSetup,
     cuda_visible_devices,
     get_cpu_affinity,
     get_ucx_config,
@@ -25,13 +22,6 @@
 )
 
 
-class IncreasedCloseTimeoutNanny(Nanny):
-    async def close(  # type:ignore[override]
-        self, timeout: float = 10.0, reason: str = "nanny-close"
-    ) -> Literal["OK"]:
-        return await super().close(timeout=timeout, reason=reason)
-
-
 class LoggedWorker(Worker):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, **kwargs)
@@ -41,7 +31,7 @@ async def start(self):
         self.data.set_address(self.address)
 
 
-class LoggedNanny(IncreasedCloseTimeoutNanny):
+class LoggedNanny(Nanny):
     def __init__(self, *args, **kwargs):
         super().__init__(*args, worker_class=LoggedWorker, **kwargs)
 
@@ -342,10 +332,17 @@ def __init__(
             enable_rdmacm=enable_rdmacm,
         )
 
-        worker_class = partial(
-            LoggedNanny if log_spilling is True else IncreasedCloseTimeoutNanny,
-            worker_class=worker_class,
-        )
+        if worker_class is not None:
+            if log_spilling is True:
+                raise ValueError(
+                    "Cannot enable `log_spilling` when `worker_class` is specified. If "
+                    "logging is needed, ensure `worker_class` is a subclass of "
+                    "`dask_cuda.local_cuda_cluster.LoggedNanny` or a subclass of "
+                    "`dask_cuda.local_cuda_cluster.LoggedWorker`, and specify "
+                    "`log_spilling=False`."
+                )
+            if not issubclass(worker_class, Nanny):
+                worker_class = partial(Nanny, worker_class=worker_class)
 
         self.pre_import = pre_import
 
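
As a side note (not part of the diff), the new `worker_class` handling above can be exercised roughly as follows; this is a minimal sketch, and the cluster arguments are illustrative only:

    # Sketch: passing a custom Worker subclass to LocalCUDACluster.
    # With the change above, a class that is not a Nanny subclass is wrapped as
    # partial(Nanny, worker_class=...), while combining a custom worker_class
    # with log_spilling=True now raises ValueError.
    from distributed import Worker

    from dask_cuda import LocalCUDACluster


    class MyWorker(Worker):
        """Hypothetical custom worker used only for illustration."""


    if __name__ == "__main__":
        cluster = LocalCUDACluster(n_workers=1, worker_class=MyWorker)
        cluster.close()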
122 changes: 122 additions & 0 deletions dask_cuda/plugins.py
@@ -0,0 +1,122 @@
+import importlib
+import os
+
+from distributed import WorkerPlugin
+
+from .utils import get_rmm_log_file_name, parse_device_memory_limit
+
+
+class CPUAffinity(WorkerPlugin):
+    def __init__(self, cores):
+        self.cores = cores
+
+    def setup(self, worker=None):
+        os.sched_setaffinity(0, self.cores)
+
+
+class RMMSetup(WorkerPlugin):
+    def __init__(
+        self,
+        initial_pool_size,
+        maximum_pool_size,
+        managed_memory,
+        async_alloc,
+        release_threshold,
+        log_directory,
+        track_allocations,
+    ):
+        if initial_pool_size is None and maximum_pool_size is not None:
+            raise ValueError(
+                "`rmm_maximum_pool_size` was specified without specifying "
+                "`rmm_pool_size`. `rmm_pool_size` must be specified to use RMM pool."
+            )
+        if async_alloc is True:
+            if managed_memory is True:
+                raise ValueError(
+                    "`rmm_managed_memory` is incompatible with `rmm_async`."
+                )
+        if async_alloc is False and release_threshold is not None:
+            raise ValueError("`rmm_release_threshold` requires `rmm_async`.")
+
+        self.initial_pool_size = initial_pool_size
+        self.maximum_pool_size = maximum_pool_size
+        self.managed_memory = managed_memory
+        self.async_alloc = async_alloc
+        self.release_threshold = release_threshold
+        self.logging = log_directory is not None
+        self.log_directory = log_directory
+        self.rmm_track_allocations = track_allocations
+
+    def setup(self, worker=None):
+        if self.initial_pool_size is not None:
+            self.initial_pool_size = parse_device_memory_limit(
+                self.initial_pool_size, alignment_size=256
+            )
+
+        if self.async_alloc:
+            import rmm
+
+            if self.release_threshold is not None:
+                self.release_threshold = parse_device_memory_limit(
+                    self.release_threshold, alignment_size=256
+                )
+
+            mr = rmm.mr.CudaAsyncMemoryResource(
+                initial_pool_size=self.initial_pool_size,
+                release_threshold=self.release_threshold,
+            )
+
+            if self.maximum_pool_size is not None:
+                self.maximum_pool_size = parse_device_memory_limit(
+                    self.maximum_pool_size, alignment_size=256
+                )
+                mr = rmm.mr.LimitingResourceAdaptor(
+                    mr, allocation_limit=self.maximum_pool_size
+                )
+
+            rmm.mr.set_current_device_resource(mr)
+            if self.logging:
+                rmm.enable_logging(
+                    log_file_name=get_rmm_log_file_name(
+                        worker, self.logging, self.log_directory
+                    )
+                )
+        elif self.initial_pool_size is not None or self.managed_memory:
+            import rmm
+
+            pool_allocator = False if self.initial_pool_size is None else True
+
+            if self.initial_pool_size is not None:
+                if self.maximum_pool_size is not None:
+                    self.maximum_pool_size = parse_device_memory_limit(
+                        self.maximum_pool_size, alignment_size=256
+                    )
+
+            rmm.reinitialize(
+                pool_allocator=pool_allocator,
+                managed_memory=self.managed_memory,
+                initial_pool_size=self.initial_pool_size,
+                maximum_pool_size=self.maximum_pool_size,
+                logging=self.logging,
+                log_file_name=get_rmm_log_file_name(
+                    worker, self.logging, self.log_directory
+                ),
+            )
+        if self.rmm_track_allocations:
+            import rmm
+
+            mr = rmm.mr.get_current_device_resource()
+            rmm.mr.set_current_device_resource(rmm.mr.TrackingResourceAdaptor(mr))
+
+
+class PreImport(WorkerPlugin):
+    def __init__(self, libraries):
+        if libraries is None:
+            libraries = []
+        elif isinstance(libraries, str):
+            libraries = libraries.split(",")
+        self.libraries = libraries
+
+    def setup(self, worker=None):
+        for library in self.libraries:
+            importlib.import_module(library)
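
For orientation (not part of this PR), the relocated classes are standard `distributed.WorkerPlugin` implementations, so they can also be registered by hand on a running cluster; the snippet below is a hypothetical sketch (dask-cuda normally attaches these plugins itself when a CUDA cluster starts):

    # Sketch: manually registering the PreImport plugin so that every worker
    # imports the listed libraries up front. The library list is illustrative.
    from dask.distributed import Client

    from dask_cuda import LocalCUDACluster
    from dask_cuda.plugins import PreImport

    if __name__ == "__main__":
        with LocalCUDACluster(n_workers=1) as cluster, Client(cluster) as client:
            client.register_worker_plugin(PreImport("cupy,numpy"))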
6 changes: 3 additions & 3 deletions dask_cuda/tests/test_dask_cuda_worker.py
@@ -40,7 +40,7 @@ def test_cuda_visible_devices_and_memory_limit_and_nthreads(loop):  # noqa: F811
             str(nthreads),
             "--no-dashboard",
             "--worker-class",
-            "dask_cuda.utils.MockWorker",
+            "dask_cuda.utils_test.MockWorker",
         ]
     ):
         with Client("127.0.0.1:9359", loop=loop) as client:
@@ -329,7 +329,7 @@ def test_cuda_mig_visible_devices_and_memory_limit_and_nthreads(loop):  # noqa:
             str(nthreads),
             "--no-dashboard",
             "--worker-class",
-            "dask_cuda.utils.MockWorker",
+            "dask_cuda.utils_test.MockWorker",
         ]
     ):
         with Client("127.0.0.1:9359", loop=loop) as client:
@@ -364,7 +364,7 @@ def test_cuda_visible_devices_uuid(loop):  # noqa: F811
             "127.0.0.1",
             "--no-dashboard",
             "--worker-class",
-            "dask_cuda.utils.MockWorker",
+            "dask_cuda.utils_test.MockWorker",
         ]
     ):
         with Client("127.0.0.1:9359", loop=loop) as client:
2 changes: 1 addition & 1 deletion dask_cuda/tests/test_explicit_comms.py
@@ -17,7 +17,7 @@
 import dask_cuda
 from dask_cuda.explicit_comms import comms
 from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
-from dask_cuda.local_cuda_cluster import IncreasedCloseTimeoutNanny
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 mp = mp.get_context("spawn")  # type: ignore
 ucp = pytest.importorskip("ucp")
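
The new `dask_cuda/utils_test.py` module itself is not shown in this excerpt. Judging from the imports above and the class deleted from `local_cuda_cluster.py`, it presumably hosts `MockWorker` and an `IncreasedCloseTimeoutNanny` along these lines (a sketch, not the file's verbatim contents):

    # Sketch of the relocated test helper, mirroring the class removed from
    # dask_cuda/local_cuda_cluster.py earlier in this diff.
    from typing import Literal

    from distributed import Nanny


    class IncreasedCloseTimeoutNanny(Nanny):
        async def close(  # type:ignore[override]
            self, timeout: float = 10.0, reason: str = "nanny-close"
        ) -> Literal["OK"]:
            return await super().close(timeout=timeout, reason=reason)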
5 changes: 5 additions & 0 deletions dask_cuda/tests/test_initialize.py
@@ -10,6 +10,7 @@
 
 from dask_cuda.initialize import initialize
 from dask_cuda.utils import get_ucx_config
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 mp = mp.get_context("spawn")  # type: ignore
 ucp = pytest.importorskip("ucp")
@@ -29,6 +30,7 @@ def _test_initialize_ucx_tcp():
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -64,6 +66,7 @@ def _test_initialize_ucx_nvlink():
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -100,6 +103,7 @@ def _test_initialize_ucx_infiniband():
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config(**kwargs)},
     ) as cluster:
         with Client(cluster) as client:
@@ -138,6 +142,7 @@ def _test_initialize_ucx_all():
         n_workers=1,
         threads_per_worker=1,
         processes=True,
+        worker_class=IncreasedCloseTimeoutNanny,
         config={"distributed.comm.ucx": get_ucx_config()},
     ) as cluster:
         with Client(cluster) as client:
2 changes: 1 addition & 1 deletion dask_cuda/tests/test_local_cuda_cluster.py
@@ -13,13 +13,13 @@
 from dask_cuda import CUDAWorker, LocalCUDACluster, utils
 from dask_cuda.initialize import initialize
 from dask_cuda.utils import (
-    MockWorker,
     get_cluster_configuration,
     get_device_total_memory,
     get_gpu_count_mig,
     get_gpu_uuid_from_index,
     print_cluster_config,
 )
+from dask_cuda.utils_test import MockWorker
 
 
 @gen_test(timeout=20)
6 changes: 5 additions & 1 deletion dask_cuda/tests/test_proxify_host_file.py
@@ -19,6 +19,7 @@
 from dask_cuda.proxify_host_file import ProxifyHostFile
 from dask_cuda.proxy_object import ProxyObject, asproxy, unproxy
 from dask_cuda.utils import get_device_total_memory
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 cupy = pytest.importorskip("cupy")
 cupy.cuda.set_allocator(None)
@@ -393,7 +394,10 @@ def is_proxy_object(x):
 
     with dask.config.set(jit_unspill_compatibility_mode=compatibility_mode):
         async with dask_cuda.LocalCUDACluster(
-            n_workers=1, jit_unspill=True, asynchronous=True
+            n_workers=1,
+            jit_unspill=True,
+            worker_class=IncreasedCloseTimeoutNanny,
+            asynchronous=True,
         ) as cluster:
             async with Client(cluster, asynchronous=True) as client:
                 ddf = dask.dataframe.from_pandas(
2 changes: 2 additions & 0 deletions dask_cuda/tests/test_proxy.py
@@ -23,6 +23,7 @@
 from dask_cuda.disk_io import SpillToDiskFile
 from dask_cuda.proxify_device_objects import proxify_device_objects
 from dask_cuda.proxify_host_file import ProxifyHostFile
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 # Make the "disk" serializer available and use a directory that is
 # removed on exit.
@@ -422,6 +423,7 @@ def task(x):
     async with dask_cuda.LocalCUDACluster(
         n_workers=1,
         protocol=protocol,
+        worker_class=IncreasedCloseTimeoutNanny,
         asynchronous=True,
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:
3 changes: 3 additions & 0 deletions dask_cuda/tests/test_spill.py
@@ -12,6 +12,7 @@
 from distributed.utils_test import gen_cluster, gen_test, loop  # noqa: F401
 
 from dask_cuda import LocalCUDACluster, utils
+from dask_cuda.utils_test import IncreasedCloseTimeoutNanny
 
 if utils.get_device_total_memory() < 1e10:
     pytest.skip("Not enough GPU memory", allow_module_level=True)
@@ -160,6 +161,7 @@ async def test_cupy_cluster_device_spill(params):
         asynchronous=True,
         device_memory_limit=params["device_memory_limit"],
         memory_limit=params["memory_limit"],
+        worker_class=IncreasedCloseTimeoutNanny,
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:
 
@@ -263,6 +265,7 @@ async def test_cudf_cluster_device_spill(params):
         asynchronous=True,
         device_memory_limit=params["device_memory_limit"],
         memory_limit=params["memory_limit"],
+        worker_class=IncreasedCloseTimeoutNanny,
     ) as cluster:
         async with Client(cluster, asynchronous=True) as client:
 