Skip to content

Commit

Permalink
Convert device_memory_resource* to device_async_resource_ref (#2269)
Browse files Browse the repository at this point in the history
Closes #2261

For reviewers:
Many of changes are simple textual replace of `rmm::mr::device_memory_resource *` with `rmm::device_async_resource_ref`.  However there are several places where RAFT used a default value of `nullptr` for `device_memory_resource*` parameters. This is incompatible with a `resource_ref`, which is a lightweight non-owning reference class, not a pointer. In most places, I was able to either remove the default parameter value, or use `rmm::mr::get_current_device_resource()`. In the case of ivf_pq, I removed the deprecated versions of `search` that took an `mr` parameter.

I removed the unused old src/util/memory_pool.cpp and its headers.

Authors:
  - Mark Harris (https://github.com/harrism)

Approvers:
  - Artem M. Chirkin (https://github.com/achirkin)
  - Dante Gama Dessavre (https://github.com/dantegd)
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #2269
  • Loading branch information
harrism authored Apr 24, 2024
1 parent 317a61c commit 71a19a2
Show file tree
Hide file tree
Showing 59 changed files with 187 additions and 438 deletions.
1 change: 0 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,6 @@ if(RAFT_COMPILE_LIBRARY)
src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu
src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu
src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu
src/util/memory_pool.cpp
)
set_target_properties(
raft_objs
Expand Down
5 changes: 3 additions & 2 deletions cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/failure_callback_resource_adaptor.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

Expand Down Expand Up @@ -130,8 +131,8 @@ class configured_raft_resources {
{
}

configured_raft_resources(configured_raft_resources&&) = default;
configured_raft_resources& operator=(configured_raft_resources&&) = default;
configured_raft_resources(configured_raft_resources&&) = delete;
configured_raft_resources& operator=(configured_raft_resources&&) = delete;
~configured_raft_resources() = default;
configured_raft_resources(const configured_raft_resources& res)
: configured_raft_resources{res.shared_res_}
Expand Down
12 changes: 7 additions & 5 deletions cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#define JSON_DIAGNOSTICS 1
#include <nlohmann/json.hpp>
Expand Down Expand Up @@ -89,10 +90,11 @@ int main(int argc, char** argv)
// and is initially sized to half of free device memory.
rmm::mr::pool_memory_resource<rmm::mr::cuda_memory_resource> pool_mr{
&cuda_mr, rmm::percent_of_free_device_memory(50)};
rmm::mr::set_current_device_resource(
&pool_mr); // Updates the current device resource pointer to `pool_mr`
rmm::mr::device_memory_resource* mr =
rmm::mr::get_current_device_resource(); // Points to `pool_mr`
return raft::bench::ann::run_main(argc, argv);
// Updates the current device resource pointer to `pool_mr`
auto old_mr = rmm::mr::set_current_device_resource(&pool_mr);
auto ret = raft::bench::ann::run_main(argc, argv);
// Restores the current device resource pointer to its previous value
rmm::mr::set_current_device_resource(old_mr);
return ret;
}
#endif
4 changes: 2 additions & 2 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cassert>
#include <fstream>
Expand Down Expand Up @@ -138,7 +138,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {
std::shared_ptr<raft::device_matrix<T, int64_t, row_major>> dataset_;
std::shared_ptr<raft::device_matrix_view<const T, int64_t, row_major>> input_dataset_v_;

inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type)
inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type)
{
switch (mem_type) {
case (AllocatorType::HostPinned): return &mr_pinned_;
Expand Down
11 changes: 9 additions & 2 deletions cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -134,7 +134,14 @@ void RaftIvfFlatGpu<T, IdxT>::search(
const T* queries, int batch_size, int k, size_t* neighbors, float* distances) const
{
static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");
raft::neighbors::ivf_flat::search(
handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances);
raft::neighbors::ivf_flat::search(handle_,
search_params_,
*index_,
queries,
batch_size,
k,
(IdxT*)neighbors,
distances,
resource::get_workspace_resource(handle_));
}
} // namespace raft::bench::ann
3 changes: 0 additions & 3 deletions cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
#include <raft/neighbors/refine.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

#include <type_traits>

namespace raft::bench::ann {
Expand Down
1 change: 1 addition & 0 deletions cpp/bench/prims/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <rmm/cuda_stream.hpp>
#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_buffer.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

Expand Down
1 change: 1 addition & 0 deletions cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <raft/util/itertools.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

namespace raft::bench::matrix {
Expand Down
15 changes: 12 additions & 3 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -27,10 +27,12 @@
#include <raft/spatial/knn/knn.cuh>
#include <raft/util/itertools.hpp>

#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/host/new_delete_resource.hpp>
#include <rmm/mr/host/pinned_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/sequence.h>

Expand Down Expand Up @@ -101,7 +103,7 @@ struct device_resource {
if (managed_) { delete res_; }
}

[[nodiscard]] auto get() const -> rmm::mr::device_memory_resource* { return res_; }
[[nodiscard]] auto get() const -> rmm::device_async_resource_ref { return res_; }

private:
const bool managed_;
Expand Down Expand Up @@ -158,8 +160,15 @@ struct ivf_flat_knn {
IdxT* out_idxs)
{
search_params.n_probes = 20;
raft::neighbors::ivf_flat::search(
handle, search_params, *index, search_items, ps.n_queries, ps.k, out_idxs, out_dists);
raft::neighbors::ivf_flat::search(handle,
search_params,
*index,
search_items,
ps.n_queries,
ps.k,
out_idxs,
out_dists,
resource::get_workspace_resource(handle));
}
};

Expand Down
1 change: 1 addition & 0 deletions cpp/bench/prims/random/subsample.cu
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_scalar.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>

#include <cub/cub.cuh>
Expand Down
44 changes: 21 additions & 23 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,14 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_vector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/gather.h>
#include <thrust/transform.h>

#include <limits>
#include <optional>
#include <tuple>
#include <type_traits>

Expand Down Expand Up @@ -91,7 +90,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
const MathT* dataset_norm,
IdxT n_rows,
LabelT* labels,
rmm::mr::device_memory_resource* mr)
rmm::device_async_resource_ref mr)
{
auto stream = resource::get_cuda_stream(handle);
switch (params.metric) {
Expand Down Expand Up @@ -263,10 +262,9 @@ void calc_centers_and_sizes(const raft::resources& handle,
const LabelT* labels,
bool reset_counters,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr)
rmm::device_async_resource_ref mr)
{
auto stream = resource::get_cuda_stream(handle);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }

if (!reset_counters) {
raft::linalg::matrixVectorOp(
Expand Down Expand Up @@ -322,12 +320,12 @@ void compute_norm(const raft::resources& handle,
IdxT dim,
IdxT n_rows,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr)
std::optional<rmm::device_async_resource_ref> mr = std::nullopt)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("compute_norm");
auto stream = resource::get_cuda_stream(handle);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }
rmm::device_uvector<MathT> mapped_dataset(0, stream, mr);
rmm::device_uvector<MathT> mapped_dataset(
0, stream, mr.value_or(resource::get_workspace_resource(handle)));

const MathT* dataset_ptr = nullptr;

Expand All @@ -338,7 +336,7 @@ void compute_norm(const raft::resources& handle,

linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream);

dataset_ptr = (const MathT*)mapped_dataset.data();
dataset_ptr = static_cast<const MathT*>(mapped_dataset.data());
}

raft::linalg::rowNorm<MathT, IdxT>(
Expand Down Expand Up @@ -376,22 +374,22 @@ void predict(const raft::resources& handle,
IdxT n_rows,
LabelT* labels,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr,
const MathT* dataset_norm = nullptr)
std::optional<rmm::device_async_resource_ref> mr = std::nullopt,
const MathT* dataset_norm = nullptr)
{
auto stream = resource::get_cuda_stream(handle);
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"predict(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }
auto mem_res = mr.value_or(resource::get_workspace_resource(handle));
auto [max_minibatch_size, _mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
rmm::device_uvector<MathT> cur_dataset(
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mr);
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mem_res);
bool need_compute_norm =
dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded ||
params.metric == raft::distance::DistanceType::L2SqrtExpanded);
rmm::device_uvector<MathT> cur_dataset_norm(
need_compute_norm ? max_minibatch_size : 0, stream, mr);
need_compute_norm ? max_minibatch_size : 0, stream, mem_res);
const MathT* dataset_norm_ptr = nullptr;
auto cur_dataset_ptr = cur_dataset.data();
for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) {
Expand All @@ -407,7 +405,7 @@ void predict(const raft::resources& handle,
// Compute the norm now if it hasn't been pre-computed.
if (need_compute_norm) {
compute_norm(
handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mr);
handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res);
dataset_norm_ptr = cur_dataset_norm.data();
} else if (dataset_norm != nullptr) {
dataset_norm_ptr = dataset_norm + offset;
Expand All @@ -422,7 +420,7 @@ void predict(const raft::resources& handle,
dataset_norm_ptr,
minibatch_size,
labels + offset,
mr);
mem_res);
}
}

Expand Down Expand Up @@ -530,7 +528,7 @@ auto adjust_centers(MathT* centers,
MathT threshold,
MappingOpT mapping_op,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* device_memory) -> bool
rmm::device_async_resource_ref device_memory) -> bool
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"adjust_centers(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
Expand Down Expand Up @@ -628,7 +626,7 @@ void balancing_em_iters(const raft::resources& handle,
uint32_t balancing_pullback,
MathT balancing_threshold,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* device_memory)
rmm::device_async_resource_ref device_memory)
{
auto stream = resource::get_cuda_stream(handle);
uint32_t balancing_counter = balancing_pullback;
Expand Down Expand Up @@ -711,7 +709,7 @@ void build_clusters(const raft::resources& handle,
LabelT* cluster_labels,
CounterT* cluster_sizes,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* device_memory,
rmm::device_async_resource_ref device_memory,
const MathT* dataset_norm = nullptr)
{
auto stream = resource::get_cuda_stream(handle);
Expand Down Expand Up @@ -853,8 +851,8 @@ auto build_fine_clusters(const raft::resources& handle,
IdxT fine_clusters_nums_max,
MathT* cluster_centers,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* managed_memory,
rmm::mr::device_memory_resource* device_memory) -> IdxT
rmm::device_async_resource_ref managed_memory,
rmm::device_async_resource_ref device_memory) -> IdxT
{
auto stream = resource::get_cuda_stream(handle);
rmm::device_uvector<IdxT> mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory);
Expand Down Expand Up @@ -971,7 +969,7 @@ void build_hierarchical(const raft::resources& handle,

// TODO: Remove the explicit managed memory- we shouldn't be creating this on the user's behalf.
rmm::mr::managed_memory_resource managed_memory;
rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle);
rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle);
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);

Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/cluster/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ void calc_centers_and_sizes(const raft::resources& handle,
X.extent(0),
labels.data_handle(),
reset_counters,
mapping_op);
mapping_op,
resource::get_workspace_resource(handle));
}

} // namespace helpers
Expand Down
19 changes: 6 additions & 13 deletions cpp/include/raft/core/device_container_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/device_ptr.h>

Expand Down Expand Up @@ -117,7 +118,7 @@ class device_uvector {
*/
explicit device_uvector(std::size_t size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
rmm::device_async_resource_ref mr)
: data_{size, stream, mr}
{
}
Expand Down Expand Up @@ -164,19 +165,11 @@ class device_uvector_policy {
public:
auto create(raft::resources const& res, size_t n) -> container_type
{
if (mr_ == nullptr) {
// NB: not using the workspace resource by default!
// The workspace resource is for short-lived temporary allocations.
return container_type(n, resource::get_cuda_stream(res));
} else {
return container_type(n, resource::get_cuda_stream(res), mr_);
}
return container_type(n, resource::get_cuda_stream(res), mr_);
}

constexpr device_uvector_policy() = default;
constexpr explicit device_uvector_policy(rmm::mr::device_memory_resource* mr) noexcept : mr_(mr)
{
}
explicit device_uvector_policy(rmm::device_async_resource_ref mr) noexcept : mr_(mr) {}

[[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference
{
Expand All @@ -192,7 +185,7 @@ class device_uvector_policy {
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }

private:
rmm::mr::device_memory_resource* mr_{nullptr};
rmm::device_async_resource_ref mr_{rmm::mr::get_current_device_resource()};
};

} // namespace raft
Expand Down
4 changes: 3 additions & 1 deletion cpp/include/raft/core/device_mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <raft/core/mdarray.hpp>
#include <raft/core/resources.hpp>

#include <rmm/resource_ref.hpp>

#include <cstdint>

namespace raft {
Expand Down Expand Up @@ -107,7 +109,7 @@ template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_device_mdarray(raft::resources const& handle,
rmm::mr::device_memory_resource* mr,
rmm::device_async_resource_ref mr,
extents<IndexType, Extents...> exts)
{
using mdarray_t = device_mdarray<ElementType, decltype(exts), LayoutPolicy>;
Expand Down
Loading

0 comments on commit 71a19a2

Please sign in to comment.