From 6f0abae95527eddcaaf0c6f3198e28875a000911 Mon Sep 17 00:00:00 2001 From: Vibhu Jawa Date: Thu, 22 Jun 2023 13:01:58 -0700 Subject: [PATCH] [REVIEW] Add scheduler_file argument to support MNMG setup (#1593) ## Add scheduler_file argument to support MNMG setup ### Overview: The primary goal is to provide more flexibility and adaptability in how the Dask cluster for testing is configured. ### Changes: 1. **Allow connecting to an existing cluster** - The creation of the `LocalCUDACluster` instances is now contingent on the presence of a `SCHEDULER_FILE` environment variable. If this variable exists, the path to the Dask scheduler file is returned instead of creating a new cluster. This change allows the use of pre-existing clusters specified via the `SCHEDULER_FILE` environment variable. 2. **Remove UCX-related flags as they are no longer needed** - Removed specific flags (`enable_tcp_over_ucx`, `enable_nvlink`, `enable_infiniband`) previously used to initialize the `LocalCUDACluster`. These flags are no longer needed since `Dask-CUDA 22.02` and `UCX >= 1.11.1`. See docs: https://docs.rapids.ai/api/dask-cuda/nightly/examples/ucx/#localcudacluster-with-automatic-configuration This could help in situations where test scenarios need to be conducted on a specific pre-existing cluster (especially for MNMG setups). ### Testing: I tested using the following setup: Start Cluster: ``` dask scheduler --scheduler-file /raid/vjawa/scheduler.json & dask-cuda-worker --scheduler-file /raid/vjawa/scheduler.json ``` Run Tests: ``` export SCHEDULER_FILE=/raid/vjawa/scheduler.json cd /home/nfs/vjawa/raft/python/raft-dask/raft_dask/test pytest . ``` Authors: - Vibhu Jawa (https://github.com/VibhuJawa) Approvers: - Corey J. 
Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1593 --- python/raft-dask/raft_dask/common/utils.py | 2 +- python/raft-dask/raft_dask/test/__init__.py | 13 ++++ python/raft-dask/raft_dask/test/conftest.py | 69 ++++++++++++------- python/raft-dask/raft_dask/test/test_comms.py | 11 ++- 4 files changed, 62 insertions(+), 33 deletions(-) create mode 100644 python/raft-dask/raft_dask/test/__init__.py diff --git a/python/raft-dask/raft_dask/common/utils.py b/python/raft-dask/raft_dask/common/utils.py index 78a899aa50..dcc53fda9a 100644 --- a/python/raft-dask/raft_dask/common/utils.py +++ b/python/raft-dask/raft_dask/common/utils.py @@ -1,4 +1,4 @@ -# Copyright (c) 2020-2022, NVIDIA CORPORATION. +# Copyright (c) 2020-2023, NVIDIA CORPORATION. # # Licensed under the Apache License, Version 2.0 (the "License"); # you may not use this file except in compliance with the License. diff --git a/python/raft-dask/raft_dask/test/__init__.py b/python/raft-dask/raft_dask/test/__init__.py new file mode 100644 index 0000000000..764e0f32fd --- /dev/null +++ b/python/raft-dask/raft_dask/test/__init__.py @@ -0,0 +1,13 @@ +# Copyright (c) 2020-2023, NVIDIA CORPORATION. +# +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. 
diff --git a/python/raft-dask/raft_dask/test/conftest.py b/python/raft-dask/raft_dask/test/conftest.py index 39ee21cbaa..d1baa684d4 100644 --- a/python/raft-dask/raft_dask/test/conftest.py +++ b/python/raft-dask/raft_dask/test/conftest.py @@ -1,54 +1,71 @@ -# Copyright (c) 2022, NVIDIA CORPORATION. +# Copyright (c) 2022-2023, NVIDIA CORPORATION. import os import pytest from dask.distributed import Client -from dask_cuda import LocalCUDACluster, initialize +from dask_cuda import LocalCUDACluster os.environ["UCX_LOG_LEVEL"] = "error" -enable_tcp_over_ucx = True -enable_nvlink = False -enable_infiniband = False - - @pytest.fixture(scope="session") def cluster(): - cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) - yield cluster - cluster.close() + scheduler_file = os.environ.get("SCHEDULER_FILE") + if scheduler_file: + yield scheduler_file + else: + cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0) + yield cluster + cluster.close() @pytest.fixture(scope="session") def ucx_cluster(): - initialize.initialize( - create_cuda_context=True, - enable_tcp_over_ucx=enable_tcp_over_ucx, - enable_nvlink=enable_nvlink, - enable_infiniband=enable_infiniband, - ) - cluster = LocalCUDACluster( - protocol="ucx", - enable_tcp_over_ucx=enable_tcp_over_ucx, - enable_nvlink=enable_nvlink, - enable_infiniband=enable_infiniband, - ) - yield cluster - cluster.close() + scheduler_file = os.environ.get("SCHEDULER_FILE") + if scheduler_file: + yield scheduler_file + else: + cluster = LocalCUDACluster( + protocol="ucx", + ) + yield cluster + cluster.close() @pytest.fixture(scope="session") def client(cluster): - client = Client(cluster) + client = create_client(cluster) yield client client.close() @pytest.fixture() def ucx_client(ucx_cluster): - client = Client(cluster) + client = create_client(ucx_cluster) yield client client.close() + + +def create_client(cluster): + """ + Create a Dask distributed client for a specified cluster. 
+ + Parameters + ---------- + cluster : LocalCUDACluster instance or str + If a LocalCUDACluster instance is provided, a client will be created + for it directly. If a string is provided, it should specify the path to + a Dask scheduler file. A client will then be created for the cluster + referenced by this scheduler file. + + Returns + ------- + dask.distributed.Client + A client connected to the specified cluster. + """ + if isinstance(cluster, LocalCUDACluster): + return Client(cluster) + else: + return Client(scheduler_file=cluster) diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 3a430f9270..5c69a94fd8 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -17,7 +17,9 @@ import pytest -from dask.distributed import Client, get_worker, wait +from dask.distributed import get_worker, wait + +from .conftest import create_client try: from raft_dask.common import ( @@ -43,9 +45,7 @@ def test_comms_init_no_p2p(cluster): - - client = Client(cluster) - + client = create_client(cluster) try: cb = Comms(verbose=True) cb.init() @@ -121,8 +121,7 @@ def func_check_uid_on_worker(sessionId, uniqueId, dask_worker=None): def test_handles(cluster): - - client = Client(cluster) + client = create_client(cluster) def _has_handle(sessionId): return local_handle(sessionId, dask_worker=get_worker()) is not None