Skip to content

Commit

Permalink
Remove explicit UCX config from tests (#1199)
Browse files Browse the repository at this point in the history
Rely on UCX defaults for selection of transport in tests, which is now the preferred way to launch setup a cluster.

Authors:
  - Peter Andreas Entschev (https://github.com/pentschev)

Approvers:
  - Benjamin Zaitlen (https://github.com/quasiben)

URL: #1199
  • Loading branch information
pentschev authored Jun 13, 2023
1 parent c97ac50 commit 83c6476
Show file tree
Hide file tree
Showing 3 changed files with 14 additions and 36 deletions.
31 changes: 0 additions & 31 deletions dask_cuda/tests/test_explicit_comms.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,6 @@
import dask_cuda
from dask_cuda.explicit_comms import comms
from dask_cuda.explicit_comms.dataframe.shuffle import shuffle as explicit_comms_shuffle
from dask_cuda.initialize import initialize
from dask_cuda.utils import get_ucx_config

mp = mp.get_context("spawn") # type: ignore
ucp = pytest.importorskip("ucp")
Expand All @@ -32,14 +30,6 @@ async def my_rank(state, arg):


def _test_local_cluster(protocol):
dask.config.update(
dask.config.global_config,
{
"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),
},
priority="new",
)

with LocalCluster(
protocol=protocol,
dashboard_address=None,
Expand Down Expand Up @@ -106,15 +96,6 @@ def check_partitions(df, npartitions):
def _test_dataframe_shuffle(backend, protocol, n_workers):
if backend == "cudf":
cudf = pytest.importorskip("cudf")
initialize(enable_tcp_over_ucx=True)
else:
dask.config.update(
dask.config.global_config,
{
"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),
},
priority="new",
)

with LocalCluster(
protocol=protocol,
Expand Down Expand Up @@ -220,17 +201,6 @@ def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
if backend == "cudf":
cudf = pytest.importorskip("cudf")

initialize(enable_tcp_over_ucx=True)
else:

dask.config.update(
dask.config.global_config,
{
"distributed.comm.ucx": get_ucx_config(enable_tcp_over_ucx=True),
},
priority="new",
)

with LocalCluster(
protocol=protocol,
dashboard_address=None,
Expand Down Expand Up @@ -287,7 +257,6 @@ def _test_jit_unspill(protocol):
threads_per_worker=1,
jit_unspill=True,
device_memory_limit="1B",
enable_tcp_over_ucx=True if protocol == "ucx" else False,
) as cluster:
with Client(cluster):
np.random.seed(42)
Expand Down
17 changes: 14 additions & 3 deletions dask_cuda/tests/test_local_cuda_cluster.py
Original file line number Diff line number Diff line change
Expand Up @@ -87,14 +87,25 @@ def get_visible_devices():
}


@pytest.mark.parametrize("protocol", ["ucx", None])
@gen_test(timeout=20)
async def test_ucx_protocol(protocol):
async def test_ucx_protocol():
pytest.importorskip("ucp")

async with LocalCUDACluster(
protocol="ucx", asynchronous=True, data=dict
) as cluster:
assert all(
ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values()
)


@gen_test(timeout=20)
async def test_explicit_ucx_with_protocol_none():
pytest.importorskip("ucp")

initialize(enable_tcp_over_ucx=True)
async with LocalCUDACluster(
protocol=protocol, enable_tcp_over_ucx=True, asynchronous=True, data=dict
protocol=None, enable_tcp_over_ucx=True, asynchronous=True, data=dict
) as cluster:
assert all(
ws.address.startswith("ucx://") for ws in cluster.scheduler.workers.values()
Expand Down
2 changes: 0 additions & 2 deletions dask_cuda/tests/test_proxy.py
Original file line number Diff line number Diff line change
Expand Up @@ -422,7 +422,6 @@ def task(x):
async with dask_cuda.LocalCUDACluster(
n_workers=1,
protocol=protocol,
enable_tcp_over_ucx=protocol == "ucx",
asynchronous=True,
) as cluster:
async with Client(cluster, asynchronous=True) as client:
Expand Down Expand Up @@ -462,7 +461,6 @@ def task(x):
async with dask_cuda.LocalCUDACluster(
n_workers=1,
protocol=protocol,
enable_tcp_over_ucx=protocol == "ucx",
asynchronous=True,
) as cluster:
async with Client(cluster, asynchronous=True) as client:
Expand Down

0 comments on commit 83c6476

Please sign in to comment.