From 8106415dcc759788fc3b5e999650ae0407ff15b8 Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 29 Jun 2023 23:10:34 +0200 Subject: [PATCH 1/2] Set pool memory resource for raft IVF ANN benchmarks --- cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h | 1 + cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h | 1 + 2 files changed, 2 insertions(+) diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 36b4931460..156b90b6a8 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -97,6 +97,7 @@ RaftIvfFlatGpu::RaftIvfFlatGpu(Metric metric, int dim, const BuildParam mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) { index_params_.metric = parse_metric_type(metric); + rmm::mr::set_current_device_resource(&managed_mr_); RAFT_CUDA_TRY(cudaGetDevice(&device_)); } diff --git a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h index c390d0bd7e..5c7d1d1eae 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h @@ -98,6 +98,7 @@ RaftIvfPQ::RaftIvfPQ(Metric metric, int dim, const BuildParam& param, f refine_ratio_(refine_ratio), mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) { + rmm::mr::set_current_device_resource(&mr_); index_params_.metric = parse_metric_type(metric); RAFT_CUDA_TRY(cudaGetDevice(&device_)); } From cd174f838b49df6b907d4b9d2a3e11e7eec7c7de Mon Sep 17 00:00:00 2001 From: Tamas Bela Feher Date: Thu, 29 Jun 2023 23:34:08 +0200 Subject: [PATCH 2/2] Fix merge error --- cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h index 156b90b6a8..cf04e07d19 100644 --- a/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h +++ b/cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h @@ -97,7 +97,7 @@ RaftIvfFlatGpu::RaftIvfFlatGpu(Metric metric, int dim, const BuildParam mr_(rmm::mr::get_current_device_resource(), 1024 * 1024 * 1024ull) { index_params_.metric = parse_metric_type(metric); - rmm::mr::set_current_device_resource(&managed_mr_); + rmm::mr::set_current_device_resource(&mr_); RAFT_CUDA_TRY(cudaGetDevice(&device_)); }