diff --git a/dask_cuda/tests/test_proxy.py b/dask_cuda/tests/test_proxy.py
index 830b403d3..1a4abafe9 100644
--- a/dask_cuda/tests/test_proxy.py
+++ b/dask_cuda/tests/test_proxy.py
@@ -16,9 +16,10 @@ from dask.sizeof import sizeof
 from distributed import Client
 from distributed.protocol.serialize import deserialize, serialize
+from distributed.utils_test import gen_test
 
 import dask_cuda
-from dask_cuda import proxy_object
+from dask_cuda import LocalCUDACluster, proxy_object
 from dask_cuda.disk_io import SpillToDiskFile
 from dask_cuda.proxify_device_objects import proxify_device_objects
 from dask_cuda.proxify_host_file import ProxifyHostFile
@@ -282,7 +283,8 @@ def test_fixed_attribute_name():
 
 
 @pytest.mark.parametrize("jit_unspill", [True, False])
-def test_spilling_local_cuda_cluster(jit_unspill):
+@gen_test(timeout=20)
+async def test_spilling_local_cuda_cluster(jit_unspill):
     """Testing spilling of a proxied cudf dataframe in a local cuda cluster"""
     cudf = pytest.importorskip("cudf")
     dask_cudf = pytest.importorskip("dask_cudf")
@@ -299,14 +301,17 @@ def task(x):
         return x
 
     # Notice, setting `device_memory_limit=1B` to trigger spilling
-    with dask_cuda.LocalCUDACluster(
-        n_workers=1, device_memory_limit="1B", jit_unspill=jit_unspill
+    async with LocalCUDACluster(
+        n_workers=1,
+        device_memory_limit="1B",
+        jit_unspill=jit_unspill,
+        asynchronous=True,
     ) as cluster:
-        with Client(cluster):
+        async with Client(cluster, asynchronous=True) as client:
             df = cudf.DataFrame({"a": range(10)})
             ddf = dask_cudf.from_cudf(df, npartitions=1)
             ddf = ddf.map_partitions(task, meta=df.head())
-            got = ddf.compute()
+            got = await client.compute(ddf)
             if isinstance(got, pandas.Series):
                 pytest.xfail(
                     "BUG fixed by "
@@ -395,7 +400,8 @@ def _pxy_deserialize(self):
 
 
 @pytest.mark.parametrize("send_serializers", [None, ("dask", "pickle"), ("cuda",)])
 @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
-def test_communicating_proxy_objects(protocol, send_serializers):
+@gen_test(timeout=20)
+async def test_communicating_proxy_objects(protocol, send_serializers):
     """Testing serialization of cuDF dataframe when communicating"""
     cudf = pytest.importorskip("cudf")
 
@@ -413,10 +419,13 @@ def task(x):
         else:
             assert serializers_used == "dask"
 
-    with dask_cuda.LocalCUDACluster(
-        n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx"
+    async with dask_cuda.LocalCUDACluster(
+        n_workers=1,
+        protocol=protocol,
+        enable_tcp_over_ucx=protocol == "ucx",
+        asynchronous=True,
     ) as cluster:
-        with Client(cluster) as client:
+        async with Client(cluster, asynchronous=True) as client:
             df = cudf.DataFrame({"a": range(10)})
             df = proxy_object.asproxy(
                 df, serializers=send_serializers, subclass=_PxyObjTest
             )
@@ -429,14 +438,14 @@ def task(x):
                 df._pxy_get().assert_on_deserializing = False
             else:
                 df._pxy_get().assert_on_deserializing = True
-            df = client.scatter(df)
-            client.submit(task, df).result()
-            client.shutdown()  # Avoids a UCX shutdown error
+            df = await client.scatter(df)
+            await client.submit(task, df)
 
 
 @pytest.mark.parametrize("protocol", ["tcp", "ucx"])
 @pytest.mark.parametrize("shared_fs", [True, False])
-def test_communicating_disk_objects(protocol, shared_fs):
+@gen_test(timeout=20)
+async def test_communicating_disk_objects(protocol, shared_fs):
     """Testing disk serialization of cuDF dataframe when communicating"""
     cudf = pytest.importorskip("cudf")
     ProxifyHostFile._spill_to_disk.shared_filesystem = shared_fs
@@ -450,16 +459,18 @@ def task(x):
         else:
            assert serializer_used == "dask"
 
-    with dask_cuda.LocalCUDACluster(
-        n_workers=1, protocol=protocol, enable_tcp_over_ucx=protocol == "ucx"
+    async with dask_cuda.LocalCUDACluster(
+        n_workers=1,
+        protocol=protocol,
+        enable_tcp_over_ucx=protocol == "ucx",
+        asynchronous=True,
     ) as cluster:
-        with Client(cluster) as client:
+        async with Client(cluster, asynchronous=True) as client:
             df = cudf.DataFrame({"a": range(10)})
             df = proxy_object.asproxy(df, serializers=("disk",), subclass=_PxyObjTest)
             df._pxy_get().assert_on_deserializing = False
-            df = client.scatter(df)
-            client.submit(task, df).result()
-            client.shutdown()  # Avoids a UCX shutdown error
+            df = await client.scatter(df)
+            await client.submit(task, df)
 
 
 @pytest.mark.parametrize("array_module", ["numpy", "cupy"])