Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Convert device_memory_resource* to device_async_resource_ref #2269

Merged
merged 6 commits into from
Apr 24, 2024
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 0 additions & 1 deletion cpp/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -565,7 +565,6 @@ if(RAFT_COMPILE_LIBRARY)
src/spatial/knn/detail/fused_l2_knn_int32_t_float.cu
src/spatial/knn/detail/fused_l2_knn_int64_t_float.cu
src/spatial/knn/detail/fused_l2_knn_uint32_t_float.cu
src/util/memory_pool.cpp
)
set_target_properties(
raft_objs
Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/ann/src/raft/raft_ann_bench_utils.h
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/failure_callback_resource_adaptor.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <memory>
#include <type_traits>
Expand Down Expand Up @@ -100,7 +101,7 @@ class shared_raft_resources {
~shared_raft_resources() noexcept { rmm::mr::set_current_device_resource(orig_resource_); }
achirkin marked this conversation as resolved.
Show resolved Hide resolved

private:
rmm::mr::device_memory_resource* orig_resource_;
rmm::device_async_resource_ref orig_resource_;
pool_mr_type pool_resource_;
mr_type resource_;
};
Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/ann/src/raft/raft_cagra_hnswlib.cu
Original file line number Diff line number Diff line change
Expand Up @@ -20,6 +20,7 @@

#include <rmm/cuda_device.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#define JSON_DIAGNOSTICS 1
#include <nlohmann/json.hpp>
Expand Down Expand Up @@ -91,7 +92,7 @@ int main(int argc, char** argv)
&cuda_mr, rmm::percent_of_free_device_memory(50)};
rmm::mr::set_current_device_resource(
&pool_mr); // Updates the current device resource pointer to `pool_mr`
rmm::mr::device_memory_resource* mr =
rmm::device_async_resource_ref mr =
rmm::mr::get_current_device_resource(); // Points to `pool_mr`
return raft::bench::ann::run_main(argc, argv);
}
Expand Down
4 changes: 2 additions & 2 deletions cpp/bench/ann/src/raft/raft_cagra_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,7 @@
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cassert>
#include <fstream>
Expand Down Expand Up @@ -138,7 +138,7 @@ class RaftCagra : public ANN<T>, public AnnGPU {
std::shared_ptr<raft::device_matrix<T, int64_t, row_major>> dataset_;
std::shared_ptr<raft::device_matrix_view<const T, int64_t, row_major>> input_dataset_v_;

inline rmm::mr::device_memory_resource* get_mr(AllocatorType mem_type)
inline rmm::device_async_resource_ref get_mr(AllocatorType mem_type)
{
switch (mem_type) {
case (AllocatorType::HostPinned): return &mr_pinned_;
Expand Down
3 changes: 0 additions & 3 deletions cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,6 @@
#include <raft/neighbors/refine.cuh>
#include <raft/util/cudart_utils.hpp>

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

#include <type_traits>

namespace raft::bench::ann {
Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/prims/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
#include <rmm/device_buffer.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <benchmark/benchmark.h>

Expand All @@ -43,7 +44,7 @@ namespace raft::bench {
*/
struct using_pool_memory_res {
private:
rmm::mr::device_memory_resource* orig_res_;
rmm::device_async_resource_ref orig_res_;
rmm::mr::cuda_memory_resource cuda_res_{};
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_res_;

Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@

#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

namespace raft::bench::matrix {

Expand Down Expand Up @@ -108,7 +109,7 @@ struct Gather : public fixture {

private:
GatherParams<IdxT> params;
rmm::mr::device_memory_resource* old_mr;
rmm::device_async_resource_ref old_mr;
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
raft::device_matrix<T, IdxT> matrix, out;
raft::host_matrix<T, IdxT> matrix_h;
Expand Down
5 changes: 3 additions & 2 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,7 @@
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/mr/host/new_delete_resource.hpp>
#include <rmm/mr/host/pinned_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/sequence.h>

Expand Down Expand Up @@ -101,11 +102,11 @@ struct device_resource {
if (managed_) { delete res_; }
}

[[nodiscard]] auto get() const -> rmm::mr::device_memory_resource* { return res_; }
[[nodiscard]] auto get() const -> rmm::device_async_resource_ref { return res_; }

private:
const bool managed_;
rmm::mr::device_memory_resource* res_;
rmm::device_async_resource_ref res_;
};

template <typename T>
Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/prims/random/subsample.cu
Original file line number Diff line number Diff line change
Expand Up @@ -28,6 +28,7 @@

#include <rmm/device_scalar.hpp>
#include <rmm/mr/device/pool_memory_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <cub/cub.cuh>

Expand Down Expand Up @@ -94,7 +95,7 @@ struct sample : public fixture {
private:
float GiB = 1073741824.0f;
raft::device_resources res;
rmm::mr::device_memory_resource* old_mr;
rmm::device_async_resource_ref old_mr;
rmm::mr::pool_memory_resource<rmm::mr::device_memory_resource> pool_mr;
sample_inputs params;
raft::device_vector<T, int64_t> out, in;
Expand Down
44 changes: 21 additions & 23 deletions cpp/include/raft/cluster/detail/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -43,15 +43,14 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_scalar.hpp>
#include <rmm/device_vector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/managed_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/gather.h>
#include <thrust/transform.h>

#include <limits>
#include <optional>
#include <tuple>
#include <type_traits>

Expand Down Expand Up @@ -91,7 +90,7 @@ inline std::enable_if_t<std::is_floating_point_v<MathT>> predict_core(
const MathT* dataset_norm,
IdxT n_rows,
LabelT* labels,
rmm::mr::device_memory_resource* mr)
rmm::device_async_resource_ref mr)
{
auto stream = resource::get_cuda_stream(handle);
switch (params.metric) {
Expand Down Expand Up @@ -263,10 +262,9 @@ void calc_centers_and_sizes(const raft::resources& handle,
const LabelT* labels,
bool reset_counters,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr)
rmm::device_async_resource_ref mr)
{
auto stream = resource::get_cuda_stream(handle);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }

if (!reset_counters) {
raft::linalg::matrixVectorOp(
Expand Down Expand Up @@ -322,12 +320,12 @@ void compute_norm(const raft::resources& handle,
IdxT dim,
IdxT n_rows,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr)
std::optional<rmm::device_async_resource_ref> mr = std::nullopt)
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope("compute_norm");
auto stream = resource::get_cuda_stream(handle);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }
rmm::device_uvector<MathT> mapped_dataset(0, stream, mr);
rmm::device_uvector<MathT> mapped_dataset(
0, stream, mr.value_or(resource::get_workspace_resource(handle)));

const MathT* dataset_ptr = nullptr;

Expand All @@ -338,7 +336,7 @@ void compute_norm(const raft::resources& handle,

linalg::unaryOp(mapped_dataset.data(), dataset, n_rows * dim, mapping_op, stream);

dataset_ptr = (const MathT*)mapped_dataset.data();
dataset_ptr = static_cast<const MathT*>(mapped_dataset.data());
}

raft::linalg::rowNorm<MathT, IdxT>(
Expand Down Expand Up @@ -376,22 +374,22 @@ void predict(const raft::resources& handle,
IdxT n_rows,
LabelT* labels,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* mr = nullptr,
const MathT* dataset_norm = nullptr)
std::optional<rmm::device_async_resource_ref> mr = std::nullopt,
const MathT* dataset_norm = nullptr)
{
auto stream = resource::get_cuda_stream(handle);
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"predict(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
if (mr == nullptr) { mr = resource::get_workspace_resource(handle); }
auto mem_res = mr.value_or(resource::get_workspace_resource(handle));
auto [max_minibatch_size, _mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);
rmm::device_uvector<MathT> cur_dataset(
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mr);
std::is_same_v<T, MathT> ? 0 : max_minibatch_size * dim, stream, mem_res);
bool need_compute_norm =
dataset_norm == nullptr && (params.metric == raft::distance::DistanceType::L2Expanded ||
params.metric == raft::distance::DistanceType::L2SqrtExpanded);
rmm::device_uvector<MathT> cur_dataset_norm(
need_compute_norm ? max_minibatch_size : 0, stream, mr);
need_compute_norm ? max_minibatch_size : 0, stream, mem_res);
const MathT* dataset_norm_ptr = nullptr;
auto cur_dataset_ptr = cur_dataset.data();
for (IdxT offset = 0; offset < n_rows; offset += max_minibatch_size) {
Expand All @@ -407,7 +405,7 @@ void predict(const raft::resources& handle,
// Compute the norm now if it hasn't been pre-computed.
if (need_compute_norm) {
compute_norm(
handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mr);
handle, cur_dataset_norm.data(), cur_dataset_ptr, dim, minibatch_size, mapping_op, mem_res);
dataset_norm_ptr = cur_dataset_norm.data();
} else if (dataset_norm != nullptr) {
dataset_norm_ptr = dataset_norm + offset;
Expand All @@ -422,7 +420,7 @@ void predict(const raft::resources& handle,
dataset_norm_ptr,
minibatch_size,
labels + offset,
mr);
mem_res);
}
}

Expand Down Expand Up @@ -530,7 +528,7 @@ auto adjust_centers(MathT* centers,
MathT threshold,
MappingOpT mapping_op,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* device_memory) -> bool
rmm::device_async_resource_ref device_memory) -> bool
{
common::nvtx::range<common::nvtx::domain::raft> fun_scope(
"adjust_centers(%zu, %u)", static_cast<size_t>(n_rows), n_clusters);
Expand Down Expand Up @@ -628,7 +626,7 @@ void balancing_em_iters(const raft::resources& handle,
uint32_t balancing_pullback,
MathT balancing_threshold,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* device_memory)
rmm::device_async_resource_ref device_memory)
{
auto stream = resource::get_cuda_stream(handle);
uint32_t balancing_counter = balancing_pullback;
Expand Down Expand Up @@ -711,7 +709,7 @@ void build_clusters(const raft::resources& handle,
LabelT* cluster_labels,
CounterT* cluster_sizes,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* device_memory,
rmm::device_async_resource_ref device_memory,
const MathT* dataset_norm = nullptr)
{
auto stream = resource::get_cuda_stream(handle);
Expand Down Expand Up @@ -853,8 +851,8 @@ auto build_fine_clusters(const raft::resources& handle,
IdxT fine_clusters_nums_max,
MathT* cluster_centers,
MappingOpT mapping_op,
rmm::mr::device_memory_resource* managed_memory,
rmm::mr::device_memory_resource* device_memory) -> IdxT
rmm::device_async_resource_ref managed_memory,
rmm::device_async_resource_ref device_memory) -> IdxT
{
auto stream = resource::get_cuda_stream(handle);
rmm::device_uvector<IdxT> mc_trainset_ids_buf(mesocluster_size_max, stream, managed_memory);
Expand Down Expand Up @@ -971,7 +969,7 @@ void build_hierarchical(const raft::resources& handle,

// TODO: Remove the explicit managed memory - we shouldn't be creating this on the user's behalf.
rmm::mr::managed_memory_resource managed_memory;
rmm::mr::device_memory_resource* device_memory = resource::get_workspace_resource(handle);
rmm::device_async_resource_ref device_memory = resource::get_workspace_resource(handle);
auto [max_minibatch_size, mem_per_row] =
calc_minibatch_size<MathT>(n_clusters, n_rows, dim, params.metric, std::is_same_v<T, MathT>);

Expand Down
3 changes: 2 additions & 1 deletion cpp/include/raft/cluster/kmeans_balanced.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -358,7 +358,8 @@ void calc_centers_and_sizes(const raft::resources& handle,
X.extent(0),
labels.data_handle(),
reset_counters,
mapping_op);
mapping_op,
resource::get_workspace_resource(handle));
}

} // namespace helpers
Expand Down
19 changes: 6 additions & 13 deletions cpp/include/raft/core/device_container_policy.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,8 @@

#include <rmm/cuda_stream_view.hpp>
#include <rmm/device_uvector.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>
#include <rmm/mr/device/per_device_resource.hpp>
#include <rmm/resource_ref.hpp>

#include <thrust/device_ptr.h>

Expand Down Expand Up @@ -117,7 +118,7 @@ class device_uvector {
*/
explicit device_uvector(std::size_t size,
rmm::cuda_stream_view stream,
rmm::mr::device_memory_resource* mr)
rmm::device_async_resource_ref mr)
: data_{size, stream, mr}
{
}
Expand Down Expand Up @@ -164,19 +165,11 @@ class device_uvector_policy {
public:
auto create(raft::resources const& res, size_t n) -> container_type
{
if (mr_ == nullptr) {
// NB: not using the workspace resource by default!
// The workspace resource is for short-lived temporary allocations.
return container_type(n, resource::get_cuda_stream(res));
} else {
return container_type(n, resource::get_cuda_stream(res), mr_);
}
return container_type(n, resource::get_cuda_stream(res), mr_);
}

constexpr device_uvector_policy() = default;
constexpr explicit device_uvector_policy(rmm::mr::device_memory_resource* mr) noexcept : mr_(mr)
{
}
explicit device_uvector_policy(rmm::device_async_resource_ref mr) noexcept : mr_(mr) {}

[[nodiscard]] constexpr auto access(container_type& c, size_t n) const noexcept -> reference
{
Expand All @@ -192,7 +185,7 @@ class device_uvector_policy {
[[nodiscard]] auto make_accessor_policy() const noexcept { return const_accessor_policy{}; }

private:
rmm::mr::device_memory_resource* mr_{nullptr};
rmm::device_async_resource_ref mr_{rmm::mr::get_current_device_resource()};
};

} // namespace raft
Expand Down
4 changes: 3 additions & 1 deletion cpp/include/raft/core/device_mdarray.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,8 @@
#include <raft/core/mdarray.hpp>
#include <raft/core/resources.hpp>

#include <rmm/resource_ref.hpp>

#include <cstdint>

namespace raft {
Expand Down Expand Up @@ -107,7 +109,7 @@ template <typename ElementType,
typename LayoutPolicy = layout_c_contiguous,
size_t... Extents>
auto make_device_mdarray(raft::resources const& handle,
rmm::mr::device_memory_resource* mr,
rmm::device_async_resource_ref mr,
extents<IndexType, Extents...> exts)
{
using mdarray_t = device_mdarray<ElementType, decltype(exts), LayoutPolicy>;
Expand Down
1 change: 1 addition & 0 deletions cpp/include/raft/core/device_resources.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -37,6 +37,7 @@

#include <rmm/cuda_stream_pool.hpp>
#include <rmm/exec_policy.hpp>
#include <rmm/mr/device/device_memory_resource.hpp>

#include <cuda_runtime.h>

Expand Down
3 changes: 1 addition & 2 deletions cpp/include/raft/distance/detail/masked_nn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -256,9 +256,8 @@ void masked_l2_nn_impl(raft::resources const& handle,
static_assert(P::Mblk == 64, "masked_l2_nn_impl only supports a policy with 64 rows per block.");

// Get stream and workspace memory resource
rmm::mr::device_memory_resource* ws_mr =
dynamic_cast<rmm::mr::device_memory_resource*>(resource::get_workspace_resource(handle));
auto stream = resource::get_cuda_stream(handle);
auto ws_mr = resource::get_workspace_resource(handle);

// Acquire temporary buffers and initialize to zero:
// 1) Adjacency matrix bitfield
Expand Down
Loading
Loading