diff --git a/dask_cuda/benchmarks/utils.py b/dask_cuda/benchmarks/utils.py index a3d51066a..a7f51ce9b 100644 --- a/dask_cuda/benchmarks/utils.py +++ b/dask_cuda/benchmarks/utils.py @@ -364,6 +364,7 @@ def setup_memory_pool( import cupy import rmm + from rmm.allocators.cupy import rmm_cupy_allocator from dask_cuda.utils import get_rmm_log_file_name @@ -380,7 +381,7 @@ def setup_memory_pool( logging=logging, log_file_name=get_rmm_log_file_name(dask_worker, logging, log_directory), ) - cupy.cuda.set_allocator(rmm.rmm_cupy_allocator) + cupy.cuda.set_allocator(rmm_cupy_allocator) if statistics: rmm.mr.set_current_device_resource( rmm.mr.StatisticsResourceAdaptor(rmm.mr.get_current_device_resource())