-
Notifications
You must be signed in to change notification settings - Fork 197
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[REVIEW] Add scheduler_file argument to support MNMG setup (#1593)
## Add scheduler_file argument to support MNMG setup ### Overview: The primary goal is to provide more flexibility and adaptability in how the Dask cluster for testing is configured. ### Changes: 1. **Allow connecting to an existing cluster** - A `LocalCUDACluster` instance is now created only when the `SCHEDULER_FILE` environment variable is not set. If the variable is set, the path to the Dask scheduler file is returned instead of creating a new cluster. This change allows the use of pre-existing clusters specified via the `SCHEDULER_FILE` environment variable. 2. **Remove UCX related flags as they are no longer needed** - Removed specific flags (`enable_tcp_over_ucx`, `enable_nvlink`, `enable_infiniband`) previously used to initialize the `LocalCUDACluster`. These flags are no longer needed as of `Dask-CUDA 22.02` with `UCX >= 1.11.1`. See docs: https://docs.rapids.ai/api/dask-cuda/nightly/examples/ucx/#localcudacluster-with-automatic-configuration Connecting to an existing cluster helps in situations where tests need to be run on a specific pre-existing cluster (especially for MNMG setups). ### Testing: I tested using the following setup: Start Cluster: ``` dask scheduler --scheduler-file /raid/vjawa/scheduler.json & dask-cuda-worker --scheduler-file /raid/vjawa/scheduler.json ``` Run Tests: ``` export SCHEDULER_FILE=/raid/vjawa/scheduler.json cd /home/nfs/vjawa/raft/python/raft-dask/raft_dask/test pytest . ``` Authors: - Vibhu Jawa (https://github.com/VibhuJawa) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: #1593
- Loading branch information
Showing
4 changed files
with
62 additions
and
33 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
# Copyright (c) 2020-2023, NVIDIA CORPORATION. | ||
# | ||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
# | ||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
# | ||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,54 +1,71 @@ | ||
# Copyright (c) 2022, NVIDIA CORPORATION. | ||
# Copyright (c) 2022-2023, NVIDIA CORPORATION. | ||
|
||
import os | ||
|
||
import pytest | ||
|
||
from dask.distributed import Client | ||
from dask_cuda import LocalCUDACluster, initialize | ||
from dask_cuda import LocalCUDACluster | ||
|
||
# Set the UCX_LOG_LEVEL environment variable to "error" at import time.
os.environ["UCX_LOG_LEVEL"] = "error" | ||
|
||
|
||
# NOTE(review): per the commit description these UCX flags are removed by this
# change; they appear here as pre-change lines of the diff.
enable_tcp_over_ucx = True | ||
enable_nvlink = False | ||
enable_infiniband = False | ||
|
||
|
||
@pytest.fixture(scope="session")
def cluster():
    """
    Provide the TCP cluster handle shared by the whole test session.

    Yields
    ------
    str or LocalCUDACluster
        The value of the ``SCHEDULER_FILE`` environment variable when it is
        set and non-empty; otherwise a newly created ``LocalCUDACluster``
        using the ``tcp`` protocol, which is closed after the session.
    """
    scheduler_file = os.environ.get("SCHEDULER_FILE")
    if scheduler_file:
        # The cluster behind the scheduler file was not started here, so
        # there is nothing to close on teardown.
        yield scheduler_file
    else:
        local_cluster = LocalCUDACluster(protocol="tcp", scheduler_port=0)
        yield local_cluster
        local_cluster.close()
|
||
|
||
@pytest.fixture(scope="session")
def ucx_cluster():
    """
    Provide the UCX cluster handle shared by the whole test session.

    Yields
    ------
    str or LocalCUDACluster
        The value of the ``SCHEDULER_FILE`` environment variable when it is
        set and non-empty; otherwise a newly created ``LocalCUDACluster``
        using the ``ucx`` protocol, which is closed after the session.
    """
    scheduler_file = os.environ.get("SCHEDULER_FILE")
    if scheduler_file:
        # The cluster behind the scheduler file was not started here, so
        # there is nothing to close on teardown.
        yield scheduler_file
    else:
        # No explicit enable_tcp_over_ucx / enable_nvlink /
        # enable_infiniband arguments are passed; only the protocol is set.
        local_cluster = LocalCUDACluster(
            protocol="ucx",
        )
        yield local_cluster
        local_cluster.close()
|
||
|
||
@pytest.fixture(scope="session")
def client(cluster):
    """
    Provide a session-scoped Dask client for the ``cluster`` fixture.

    Parameters
    ----------
    cluster : LocalCUDACluster instance or str
        Value yielded by the ``cluster`` fixture: either a cluster object or
        the path to a Dask scheduler file.

    Yields
    ------
    dask.distributed.Client
        Client returned by ``create_client``; closed after the session.
    """
    # create_client handles both cluster objects and scheduler-file paths,
    # so exactly one client is created here.
    dask_client = create_client(cluster)
    yield dask_client
    dask_client.close()
|
||
|
||
@pytest.fixture()
def ucx_client(ucx_cluster):
    """
    Provide a per-test Dask client for the ``ucx_cluster`` fixture.

    Parameters
    ----------
    ucx_cluster : LocalCUDACluster instance or str
        Value yielded by the ``ucx_cluster`` fixture: either a cluster object
        or the path to a Dask scheduler file.

    Yields
    ------
    dask.distributed.Client
        Client returned by ``create_client``; closed after each test.
    """
    # Connect to the fixture argument; the bare name ``cluster`` is not a
    # parameter here and would refer to the module-level fixture function.
    dask_client = create_client(ucx_cluster)
    yield dask_client
    dask_client.close()
|
||
|
||
def create_client(cluster):
    """
    Build a Dask distributed client for the given cluster handle.

    Parameters
    ----------
    cluster : LocalCUDACluster instance or str
        A ``LocalCUDACluster`` instance is passed to ``Client`` directly.
        Any other value is passed to ``Client`` as ``scheduler_file``, i.e.
        it is expected to be the path to a Dask scheduler file.

    Returns
    -------
    dask.distributed.Client
        A client connected to the specified cluster.
    """
    if not isinstance(cluster, LocalCUDACluster):
        # Not a live cluster object: treat it as a scheduler-file path.
        return Client(scheduler_file=cluster)
    return Client(cluster)
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters