Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Migrate from raft::device_resources -> raft::resources #1510

Merged
merged 10 commits into from
May 16, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
The table of contents is too big for display.
Diff view
Diff view
  •  
  •  
  •  
3 changes: 2 additions & 1 deletion cpp/bench/ann/src/raft/raft_ivf_flat_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/detail/distance.cuh>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/unary_op.cuh>
Expand Down Expand Up @@ -137,7 +138,7 @@ void RaftIvfFlatGpu<T, IdxT>::search(
static_assert(sizeof(size_t) == sizeof(IdxT), "IdxT is incompatible with size_t");
raft::neighbors::ivf_flat::search(
handle_, search_params_, *index_, queries, batch_size, k, (IdxT*)neighbors, distances, mr_ptr);
handle_.sync_stream();
resource::sync_stream(handle_);
return;
}
} // namespace raft::bench::ann
18 changes: 12 additions & 6 deletions cpp/bench/ann/src/raft/raft_ivf_pq_wrapper.h
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,7 @@
#include <raft/core/host_mdarray.hpp>
#include <raft/core/host_mdspan.hpp>
#include <raft/core/logger.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/linalg/unary_op.cuh>
#include <raft/neighbors/ivf_pq_types.hpp>
Expand Down Expand Up @@ -176,11 +177,14 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
auto neighbors_host = raft::make_host_matrix<IdxT, IdxT>(batch_size, k);
auto distances_host = raft::make_host_matrix<float, IdxT>(batch_size, k);

raft::copy(queries_host.data_handle(), queries, queries_host.size(), handle_.get_stream());
raft::copy(queries_host.data_handle(),
queries,
queries_host.size(),
resource::get_cuda_stream(handle_));
raft::copy(candidates_host.data_handle(),
candidates.data_handle(),
candidates_host.size(),
handle_.get_stream());
resource::get_cuda_stream(handle_));

auto dataset_v = raft::make_host_matrix_view<const T, IdxT>(
dataset_.data_handle(), batch_size, index_->dim());
Expand All @@ -196,9 +200,11 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
raft::copy(neighbors,
(size_t*)neighbors_host.data_handle(),
neighbors_host.size(),
handle_.get_stream());
raft::copy(
distances, distances_host.data_handle(), distances_host.size(), handle_.get_stream());
resource::get_cuda_stream(handle_));
raft::copy(distances,
distances_host.data_handle(),
distances_host.size(),
resource::get_cuda_stream(handle_));
}
} else {
auto queries_v =
Expand All @@ -209,7 +215,7 @@ void RaftIvfPQ<T, IdxT>::search(const T* queries,
raft::runtime::neighbors::ivf_pq::search(
handle_, search_params_, *index_, queries_v, neighbors_v, distances_v);
}
handle_.sync_stream();
resource::sync_stream(handle_);
return;
}
} // namespace raft::bench::ann
3 changes: 2 additions & 1 deletion cpp/bench/prims/cluster/kmeans_balanced.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@

#include <common/benchmark.hpp>
#include <raft/cluster/kmeans_balanced.cuh>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/random/rng.cuh>

namespace raft::bench::cluster {
Expand Down Expand Up @@ -54,7 +55,7 @@ struct KMeansBalanced : public fixture {
raft::random::uniform(
rng, X.data_handle(), params.data.rows * params.data.cols, kRangeMin, kRangeMax, stream);
}
handle.sync_stream(stream);
resource::sync_stream(handle, stream);
}

void allocate_temp_buffers(const ::benchmark::State& state) override
Expand Down
5 changes: 3 additions & 2 deletions cpp/bench/prims/common/benchmark.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <memory>
#include <raft/core/resource/cuda_stream.hpp>

#include <raft/core/detail/macros.hpp>
#include <raft/core/device_mdarray.hpp>
Expand Down Expand Up @@ -113,7 +114,7 @@ class fixture {
raft::device_resources handle;
rmm::cuda_stream_view stream;

fixture(bool use_pool_memory_resource = false) : stream{handle.get_stream()}
fixture(bool use_pool_memory_resource = false) : stream{resource::get_cuda_stream(handle)}
{
// Cache memory pool between test runs, since it is expensive to create.
// This speeds up the time required to run the select_k bench by over 3x.
Expand Down Expand Up @@ -209,7 +210,7 @@ class BlobsFixture : public fixture {
(T)blobs_params.center_box_min,
(T)blobs_params.center_box_max,
blobs_params.seed);
this->handle.sync_stream(stream);
resource::sync_stream(this->handle, stream);
}

protected:
Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/prims/distance/fused_l2_nn.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <common/benchmark.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/distance/fused_l2_nn.cuh>
#include <raft/linalg/norm.cuh>
#include <raft/util/cudart_utils.hpp>
Expand Down Expand Up @@ -74,7 +75,7 @@ struct fusedl2nn : public fixture {
raft::linalg::L2Norm,
true,
stream);
handle.sync_stream(stream);
resource::sync_stream(handle, stream);
}

void allocate_temp_buffers(const ::benchmark::State& state) override
Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/prims/distance/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
#include <common/benchmark.hpp>
#include <memory>
#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cublas_handle.hpp>
#include <raft/distance/distance_types.hpp>
#include <raft/distance/kernels.cuh>
#include <raft/random/rng.cuh>
Expand All @@ -40,7 +41,7 @@ struct GramMatrix : public fixture {
: params(p), handle(stream), A(0, stream), B(0, stream), C(0, stream)
{
kernel = std::unique_ptr<GramMatrixBase<T>>(
KernelFactory<T>::create(p.kernel_params, handle.get_cublas_handle()));
KernelFactory<T>::create(p.kernel_params, resource::get_cublas_handle(handle)));

A.resize(params.m * params.k, stream);
B.resize(params.k * params.n, stream);
Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/prims/matrix/argmin.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <common/benchmark.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/matrix/argmin.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>
Expand All @@ -40,7 +41,7 @@ struct Argmin : public fixture {
raft::random::RngState rng{1234};
raft::random::uniform(
rng, matrix.data_handle(), params.rows * params.cols, T(-1), T(1), stream);
handle.sync_stream(stream);
resource::sync_stream(handle, stream);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
3 changes: 2 additions & 1 deletion cpp/bench/prims/matrix/gather.cu
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
*/

#include <common/benchmark.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/matrix/gather.cuh>
#include <raft/random/rng.cuh>
#include <raft/util/itertools.hpp>
Expand Down Expand Up @@ -57,7 +58,7 @@ struct Gather : public fixture {
if constexpr (Conditional) {
raft::random::uniform(rng, stencil.data_handle(), params.map_length, T(-1), T(1), stream);
}
handle.sync_stream(stream);
resource::sync_stream(handle, stream);
}

void run_benchmark(::benchmark::State& state) override
Expand Down
19 changes: 13 additions & 6 deletions cpp/bench/prims/neighbors/knn.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
#pragma once

#include <common/benchmark.hpp>
#include <raft/core/resource/device_id.hpp>

#include <raft/random/rng.cuh>

Expand Down Expand Up @@ -311,12 +312,18 @@ struct knn : public fixture {
RAFT_CUDA_TRY(cudaHostGetDevicePointer(&data_ptr, data_host_.data(), 0));
break;
case TransferStrategy::MANAGED: // sic! using std::memcpy rather than cuda copy
RAFT_CUDA_TRY(cudaMemAdvise(
data_ptr, allocation_size, cudaMemAdviseSetPreferredLocation, handle.get_device()));
RAFT_CUDA_TRY(cudaMemAdvise(
data_ptr, allocation_size, cudaMemAdviseSetAccessedBy, handle.get_device()));
RAFT_CUDA_TRY(cudaMemAdvise(
data_ptr, allocation_size, cudaMemAdviseSetReadMostly, handle.get_device()));
RAFT_CUDA_TRY(cudaMemAdvise(data_ptr,
allocation_size,
cudaMemAdviseSetPreferredLocation,
resource::get_device_id(handle)));
RAFT_CUDA_TRY(cudaMemAdvise(data_ptr,
allocation_size,
cudaMemAdviseSetAccessedBy,
resource::get_device_id(handle)));
RAFT_CUDA_TRY(cudaMemAdvise(data_ptr,
allocation_size,
cudaMemAdviseSetReadMostly,
resource::get_device_id(handle)));
std::memcpy(data_ptr, data_host_.data(), allocation_size);
break;
default: break;
Expand Down
16 changes: 9 additions & 7 deletions cpp/include/raft/cluster/detail/agglomerative.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#pragma once

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand Down Expand Up @@ -100,7 +102,7 @@ class UnionFind {
* @param[out] out_size cluster sizes of output
*/
template <typename value_idx, typename value_t>
void build_dendrogram_host(raft::device_resources const& handle,
void build_dendrogram_host(raft::resources const& handle,
const value_idx* rows,
const value_idx* cols,
const value_t* data,
Expand All @@ -109,7 +111,7 @@ void build_dendrogram_host(raft::device_resources const& handle,
value_t* out_delta,
value_idx* out_size)
{
auto stream = handle.get_stream();
auto stream = resource::get_cuda_stream(handle);

value_idx n_edges = nnz;

Expand All @@ -121,7 +123,7 @@ void build_dendrogram_host(raft::device_resources const& handle,
update_host(mst_dst_h.data(), cols, n_edges, stream);
update_host(mst_weights_h.data(), data, n_edges, stream);

handle.sync_stream(stream);
resource::sync_stream(handle, stream);

std::vector<value_idx> children_h(n_edges * 2);
std::vector<value_idx> out_size_h(n_edges);
Expand Down Expand Up @@ -236,14 +238,14 @@ struct init_label_roots {
* @param n_leaves
*/
template <typename value_idx, int tpb = 256>
void extract_flattened_clusters(raft::device_resources const& handle,
void extract_flattened_clusters(raft::resources const& handle,
value_idx* labels,
const value_idx* children,
size_t n_clusters,
size_t n_leaves)
{
auto stream = handle.get_stream();
auto thrust_policy = handle.get_thrust_policy();
auto stream = resource::get_cuda_stream(handle);
auto thrust_policy = resource::get_thrust_policy(handle);

// Handle special case where n_clusters == 1
if (n_clusters == 1) {
Expand Down
26 changes: 14 additions & 12 deletions cpp/include/raft/cluster/detail/connectivities.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -16,7 +16,9 @@

#pragma once

#include <raft/core/device_resources.hpp>
#include <raft/core/resource/cuda_stream.hpp>
#include <raft/core/resource/thrust_policy.hpp>
#include <raft/core/resources.hpp>
#include <raft/util/cuda_utils.cuh>
#include <raft/util/cudart_utils.hpp>

Expand All @@ -40,7 +42,7 @@ namespace raft::cluster::detail {

template <raft::cluster::LinkageDistance dist_type, typename value_idx, typename value_t>
struct distance_graph_impl {
void run(raft::device_resources const& handle,
void run(raft::resources const& handle,
const value_t* X,
size_t m,
size_t n,
Expand All @@ -58,7 +60,7 @@ struct distance_graph_impl {
*/
template <typename value_idx, typename value_t>
struct distance_graph_impl<raft::cluster::LinkageDistance::KNN_GRAPH, value_idx, value_t> {
void run(raft::device_resources const& handle,
void run(raft::resources const& handle,
const value_t* X,
size_t m,
size_t n,
Expand All @@ -68,8 +70,8 @@ struct distance_graph_impl<raft::cluster::LinkageDistance::KNN_GRAPH, value_idx,
rmm::device_uvector<value_t>& data,
int c)
{
auto stream = handle.get_stream();
auto thrust_policy = handle.get_thrust_policy();
auto stream = resource::get_cuda_stream(handle);
auto thrust_policy = resource::get_thrust_policy(handle);

// Need to symmetrize knn into undirected graph
raft::sparse::COO<value_t, value_idx> knn_graph_coo(stream);
Expand Down Expand Up @@ -127,7 +129,7 @@ __global__ void fill_indices2(value_idx* indices, size_t m, size_t nnz)
* @param[out] data
*/
template <typename value_idx, typename value_t>
void pairwise_distances(const raft::device_resources& handle,
void pairwise_distances(const raft::resources& handle,
const value_t* X,
size_t m,
size_t n,
Expand All @@ -136,8 +138,8 @@ void pairwise_distances(const raft::device_resources& handle,
value_idx* indices,
value_t* data)
{
auto stream = handle.get_stream();
auto exec_policy = handle.get_thrust_policy();
auto stream = resource::get_cuda_stream(handle);
auto exec_policy = resource::get_thrust_policy(handle);

value_idx nnz = m * m;

Expand Down Expand Up @@ -175,7 +177,7 @@ void pairwise_distances(const raft::device_resources& handle,
*/
template <typename value_idx, typename value_t>
struct distance_graph_impl<raft::cluster::LinkageDistance::PAIRWISE, value_idx, value_t> {
void run(const raft::device_resources& handle,
void run(const raft::resources& handle,
const value_t* X,
size_t m,
size_t n,
Expand All @@ -185,7 +187,7 @@ struct distance_graph_impl<raft::cluster::LinkageDistance::PAIRWISE, value_idx,
rmm::device_uvector<value_t>& data,
int c)
{
auto stream = handle.get_stream();
auto stream = resource::get_cuda_stream(handle);

size_t nnz = m * m;

Expand Down Expand Up @@ -213,7 +215,7 @@ struct distance_graph_impl<raft::cluster::LinkageDistance::PAIRWISE, value_idx,
* which will guarantee k <= log(n) + c
*/
template <typename value_idx, typename value_t, raft::cluster::LinkageDistance dist_type>
void get_distance_graph(raft::device_resources const& handle,
void get_distance_graph(raft::resources const& handle,
const value_t* X,
size_t m,
size_t n,
Expand All @@ -223,7 +225,7 @@ void get_distance_graph(raft::device_resources const& handle,
rmm::device_uvector<value_t>& data,
int c)
{
auto stream = handle.get_stream();
auto stream = resource::get_cuda_stream(handle);

indptr.resize(m + 1, stream);

Expand Down
Loading