Skip to content

Commit

Permalink
Replace call to mr::get_mem_info() (rapidsai#2099)
Browse files Browse the repository at this point in the history
In the near future memory resources will no longer have `get_mem_info()`.

There is one place in RAFT that uses this. This PR replaces it with `rmm::available_device_memory()` which just calls `cudaMemGetInfo()`.

Authors:
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: rapidsai#2099
  • Loading branch information
harrism authored Jan 18, 2024
1 parent 1e4961e commit 7cab0c3
Showing 1 changed file with 13 additions and 11 deletions.
24 changes: 13 additions & 11 deletions cpp/include/raft/neighbors/detail/knn_brute_force.cuh
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
/*
* Copyright (c) 2023, NVIDIA CORPORATION.
* Copyright (c) 2023-2024, NVIDIA CORPORATION.
*
* Licensed under the Apache License, Version 2.0 (the "License");
* you may not use this file except in compliance with the License.
Expand All @@ -20,14 +20,6 @@
#include <raft/core/resource/cuda_stream_pool.hpp>
#include <raft/core/resource/device_memory_resource.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>
#include <rmm/cuda_stream_pool.hpp>

#include <rmm/device_uvector.hpp>

#include <cstdint>
#include <iostream>
#include <raft/core/resources.hpp>
#include <raft/distance/detail/distance_ops/l2_exp.cuh>
#include <raft/distance/distance.cuh>
Expand All @@ -42,9 +34,19 @@
#include <raft/spatial/knn/detail/fused_l2_knn.cuh>
#include <raft/spatial/knn/detail/haversine_distance.cuh>
#include <raft/spatial/knn/detail/processing.cuh>
#include <set>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/cuda_device.hpp>
#include <rmm/cuda_stream_pool.hpp>
#include <rmm/device_uvector.hpp>

#include <thrust/iterator/transform_iterator.h>

#include <cstdint>
#include <iostream>
#include <set>

namespace raft::neighbors::detail {
using namespace raft::spatial::knn::detail;
using namespace raft::spatial::knn;
Expand Down Expand Up @@ -78,7 +80,7 @@ void tiled_brute_force_knn(const raft::resources& handle,
size_t tile_cols = 0;
auto stream = resource::get_cuda_stream(handle);
auto device_memory = resource::get_workspace_resource(handle);
auto total_mem = device_memory->get_mem_info(stream).second;
auto total_mem = rmm::available_device_memory().second;
faiss_select::chooseTileSize(m, n, d, sizeof(ElementType), total_mem, tile_rows, tile_cols);

// for unittesting, its convenient to be able to put a max size on the tiles
Expand Down

0 comments on commit 7cab0c3

Please sign in to comment.