diff --git a/dask_cuda/explicit_comms/dataframe/shuffle.py b/dask_cuda/explicit_comms/dataframe/shuffle.py
index 84bc55701..d79b08a40 100644
--- a/dask_cuda/explicit_comms/dataframe/shuffle.py
+++ b/dask_cuda/explicit_comms/dataframe/shuffle.py
@@ -477,9 +477,14 @@ def shuffle(
 
     # Get batchsize
     max_num_inkeys = max(len(k) for k in rank_to_inkeys.values())
-    batchsize = batchsize or dask.config.get("explicit_comms-batchsize", 1)
+    batchsize = batchsize or dask.config.get("explicit-comms-batchsize", 1)
     if batchsize == -1:
         batchsize = max_num_inkeys
+    if not isinstance(batchsize, int) or batchsize < 0:
+        raise ValueError(
+            "explicit-comms-batchsize must be a "
+            f"positive integer or -1 (was '{batchsize}')"
+        )
 
     # Get number of rounds of dataframe partitioning and all-to-all communication.
     num_rounds = ceil(max_num_inkeys / batchsize)
diff --git a/dask_cuda/tests/test_explicit_comms.py b/dask_cuda/tests/test_explicit_comms.py
index 88e1294cb..413bf5bdd 100644
--- a/dask_cuda/tests/test_explicit_comms.py
+++ b/dask_cuda/tests/test_explicit_comms.py
@@ -1,5 +1,7 @@
 import asyncio
 import multiprocessing as mp
+import os
+from unittest.mock import patch
 
 import numpy as np
 import pandas as pd
@@ -172,8 +174,9 @@ def test_dataframe_shuffle(backend, protocol, nworkers):
     assert not p.exitcode
 
 
-def _test_dask_use_explicit_comms():
-    def check_shuffle(in_cluster):
+@pytest.mark.parametrize("in_cluster", [True, False])
+def test_dask_use_explicit_comms(in_cluster):
+    def check_shuffle():
         """Check if shuffle use explicit-comms by search for keys named
         'explicit-comms-shuffle'
         """
@@ -189,23 +192,28 @@ def check_shuffle(in_cluster):
             else:
                 # If not in cluster, we cannot use explicit comms
                 assert all(name not in str(key) for key in res.dask)
-    with LocalCluster(
-        protocol="tcp",
-        dashboard_address=None,
-        n_workers=2,
-        threads_per_worker=1,
-        processes=True,
-    ) as cluster:
-        with Client(cluster):
-            check_shuffle(True)
-            check_shuffle(False)
-
-
-def test_dask_use_explicit_comms():
-    p = mp.Process(target=_test_dask_use_explicit_comms)
-    p.start()
-    p.join()
-    assert not p.exitcode
+        if in_cluster:
+            # We check environment variables by setting an illegal batchsize
+            with patch.dict(
+                os.environ,
+                {"DASK_EXPLICIT_COMMS": "1", "DASK_EXPLICIT_COMMS_BATCHSIZE": "-2"},
+            ):
+                dask.config.refresh()  # Trigger re-read of the environment variables
+                with pytest.raises(ValueError, match="explicit-comms-batchsize"):
+                    ddf.shuffle(on="key", npartitions=4, shuffle="tasks")
+
+    if in_cluster:
+        with LocalCluster(
+            protocol="tcp",
+            dashboard_address=None,
+            n_workers=2,
+            threads_per_worker=1,
+            processes=True,
+        ) as cluster:
+            with Client(cluster):
+                check_shuffle()
+    else:
+        check_shuffle()
 
 
 def _test_dataframe_shuffle_merge(backend, protocol, n_workers):
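
Note (not part of the patch): a minimal sketch of how the renamed `explicit-comms-batchsize` key behaves after this change, assuming only what the hunks above show (`-1` expands to the maximum number of input keys, i.e. a single round, while other negative or non-integer values now raise `ValueError`). `dask.config.set`/`dask.config.get` are standard Dask config APIs, and Dask treats `_` and `-` in key names interchangeably, which is why the test can exercise the same path through the `DASK_EXPLICIT_COMMS_BATCHSIZE` environment variable plus `dask.config.refresh()`.

    import dask

    # Illustration only: set the renamed key programmatically and read it back.
    with dask.config.set({"explicit-comms-batchsize": -1}):
        assert dask.config.get("explicit-comms-batchsize") == -1

    # Inside shuffle(), -1 is replaced by max_num_inkeys, so num_rounds == 1.
    # A value such as -2 (what the test injects via the environment) now fails
    # the isinstance/negativity check and raises ValueError instead of being
    # passed on to ceil(max_num_inkeys / batchsize).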