diff --git a/python/raft-dask/raft_dask/test/__init__.py b/python/raft-dask/raft_dask/test/__init__.py deleted file mode 100644 index 764e0f32fd..0000000000 --- a/python/raft-dask/raft_dask/test/__init__.py +++ /dev/null @@ -1,13 +0,0 @@ -# Copyright (c) 2020-2023, NVIDIA CORPORATION. -# -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. diff --git a/python/raft-dask/raft_dask/test/test_comms.py b/python/raft-dask/raft_dask/test/test_comms.py index 5c69a94fd8..68c9fee556 100644 --- a/python/raft-dask/raft_dask/test/test_comms.py +++ b/python/raft-dask/raft_dask/test/test_comms.py @@ -17,9 +17,8 @@ import pytest -from dask.distributed import get_worker, wait - -from .conftest import create_client +from dask.distributed import Client, get_worker, wait +from dask_cuda import LocalCUDACluster try: from raft_dask.common import ( @@ -44,6 +43,29 @@ pytestmark = pytest.mark.skip +def create_client(cluster): + """ + Create a Dask distributed client for a specified cluster. + + Parameters + ---------- + cluster : LocalCUDACluster instance or str + If a LocalCUDACluster instance is provided, a client will be created + for it directly. If a string is provided, it should specify the path to + a Dask scheduler file. A client will then be created for the cluster + referenced by this scheduler file. + + Returns + ------- + dask.distributed.Client + A client connected to the specified cluster. + """ + if isinstance(cluster, LocalCUDACluster): + return Client(cluster) + else: + return Client(scheduler_file=cluster) + + def test_comms_init_no_p2p(cluster): client = create_client(cluster) try: