diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 36b4931460..cf04e07d19 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -97,6 +97,7 @@ RaftIvfFlatGpu::RaftIvfFlatGpu(Metric metric, int dim, const BuildParam mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) { index_params_.metric = parse_metric_type(metric); + rmm::mr::set_current_device_resource(&mr_); RAFT_CUDA_TRY(cudaGetDevice(&device_)); } diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index c390d0bd7e..5c7d1d1eae 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -98,6 +98,7 @@ RaftIvfPQ::RaftIvfPQ(Metric metric, int dim, const BuildParam& param, f refine_ratio_(refine_ratio), mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) { + rmm::mr::set_current_device_resource(&mr_); index_params_.metric = parse_metric_type(metric); RAFT_CUDA_TRY(cudaGetDevice(&device_)); }