From 7cab0c3e20ba8252a0ad64505c57bd003d39c4a3 Mon Sep 17 00:00:00 2001 From: Mark Harris <783069+harrism@users.noreply.github.com> Date: Thu, 18 Jan 2024 22:43:05 +1100 Subject: [PATCH] Replace call to mr::get_mem_info() (#2099) In the near future memory resources will no longer have `get_mem_info()`. There is one place in RAFT that uses this. This PR replaces it with `rmm::available_device_memory()` which just calls `cudaMemGetInfo()`. Authors: - Mark Harris (https://github.com/harrism) Approvers: - Artem M. Chirkin (https://github.com/achirkin) - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/2099 --- .../raft/neighbors/detail/knn_brute_force.cuh | 24 ++++++++++--------- 1 file changed, 13 insertions(+), 11 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh index 27ef00e385..dff9aceb8d 100644 --- a/cpp/include/raft/neighbors/detail/knn_brute_force.cuh +++ b/cpp/include/raft/neighbors/detail/knn_brute_force.cuh @@ -1,5 +1,5 @@ /* - * Copyright (c) 2023, NVIDIA CORPORATION. + * Copyright (c) 2023-2024, NVIDIA CORPORATION. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -20,14 +20,6 @@ #include #include #include -#include -#include -#include - -#include - -#include -#include #include #include #include @@ -42,9 +34,19 @@ #include #include #include -#include +#include +#include + +#include +#include +#include + #include +#include +#include +#include + namespace raft::neighbors::detail { using namespace raft::spatial::knn::detail; using namespace raft::spatial::knn; @@ -78,7 +80,7 @@ void tiled_brute_force_knn(const raft::resources& handle, size_t tile_cols = 0; auto stream = resource::get_cuda_stream(handle); auto device_memory = resource::get_workspace_resource(handle); - auto total_mem = device_memory->get_mem_info(stream).second; + auto total_mem = rmm::available_device_memory().second; faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols); // for unittesting, its convenient to be able to put a max size on the tiles