diff --git a/cpp/include/raft/neighbors/cagra.cuh b/cpp/include/raft/neighbors/cagra.cuh
index 6bb7beca55..903d0571dc 100644
--- a/cpp/include/raft/neighbors/cagra.cuh
+++ b/cpp/include/raft/neighbors/cagra.cuh
@@ -256,13 +256,17 @@ index<T, IdxT> build(raft::resources const& res,
     graph_degree = intermediate_degree;
   }
 
-  auto knn_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree);
+  std::optional<raft::host_matrix<IdxT, int64_t>> knn_graph(
+    raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), intermediate_degree));
 
-  build_knn_graph(res, dataset, knn_graph.view());
+  build_knn_graph(res, dataset, knn_graph->view());
 
   auto cagra_graph = raft::make_host_matrix<IdxT, int64_t>(dataset.extent(0), graph_degree);
 
-  optimize<IdxT>(res, knn_graph.view(), cagra_graph.view());
+  optimize<IdxT>(res, knn_graph->view(), cagra_graph.view());
+
+  // free intermediate graph before trying to create the index
+  knn_graph.reset();
 
   // Construct an index from dataset and optimized knn graph.
   return index<T, IdxT>(res, params.metric, dataset, raft::make_const_mdspan(cagra_graph.view()));
diff --git a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
index 0558d7ea39..18d451be60 100644
--- a/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
+++ b/cpp/include/raft/neighbors/detail/cagra/graph_core.cuh
@@ -334,18 +334,13 @@ void optimize(raft::resources const& res,
   auto output_graph_ptr = new_graph.data_handle();
   const IdxT graph_size = new_graph.extent(0);
 
-  auto pruned_graph = raft::make_host_matrix<IdxT, int64_t>(graph_size, output_graph_degree);
-
   {
     //
     // Prune kNN graph
    //
-    auto d_input_graph =
-      raft::make_device_matrix<IdxT, int64_t>(res, graph_size, input_graph_degree);
-
-    auto detour_count = raft::make_host_matrix<uint8_t, int64_t>(graph_size, input_graph_degree);
     auto d_detour_count =
       raft::make_device_matrix<uint8_t, int64_t>(res, graph_size, input_graph_degree);
+
     RAFT_CUDA_TRY(cudaMemsetAsync(d_detour_count.data_handle(),
                                   0xff,
                                   graph_size * input_graph_degree * sizeof(uint8_t),
@@ -376,24 +371,13 @@ void optimize(raft::resources const& res,
     const double time_prune_start = cur_time();
     RAFT_LOG_DEBUG("# Pruning kNN Graph on GPUs\r");
 
-    raft::copy(d_input_graph.data_handle(),
-               input_graph_ptr,
-               graph_size * input_graph_degree,
-               resource::get_cuda_stream(res));
-    void (*kernel_prune)(const IdxT* const,
-                         const uint32_t,
-                         const uint32_t,
-                         const uint32_t,
-                         const uint32_t,
-                         const uint32_t,
-                         uint8_t* const,
-                         uint32_t* const,
-                         uint64_t* const);
+    // Copy input_graph_ptr over to device if necessary
+    device_matrix_view_from_host<IdxT, int64_t> d_input_graph(
+      res,
+      raft::make_host_matrix_view<IdxT, int64_t>(input_graph_ptr, graph_size, input_graph_degree));
 
     constexpr int MAX_DEGREE = 1024;
-    if (input_graph_degree <= MAX_DEGREE) {
-      kernel_prune = kern_prune<MAX_DEGREE, IdxT>;
-    } else {
+    if (input_graph_degree > MAX_DEGREE) {
       RAFT_FAIL(
         "The degree of input knn graph is too large (%u). "
" "It must be equal to or smaller than %d.", @@ -410,16 +394,17 @@ void optimize(raft::resources const& res, dev_stats.data_handle(), 0, sizeof(uint64_t) * 2, resource::get_cuda_stream(res))); for (uint32_t i_batch = 0; i_batch < num_batch; i_batch++) { - kernel_prune<<>>( - d_input_graph.data_handle(), - graph_size, - input_graph_degree, - output_graph_degree, - batch_size, - i_batch, - d_detour_count.data_handle(), - d_num_no_detour_edges.data_handle(), - dev_stats.data_handle()); + kern_prune + <<>>( + d_input_graph.data_handle(), + graph_size, + input_graph_degree, + output_graph_degree, + batch_size, + i_batch, + d_detour_count.data_handle(), + d_num_no_detour_edges.data_handle(), + dev_stats.data_handle()); resource::sync_stream(res); RAFT_LOG_DEBUG( "# Pruning kNN Graph on GPUs (%.1lf %%)\r", @@ -428,10 +413,7 @@ void optimize(raft::resources const& res, resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); - raft::copy(detour_count.data_handle(), - d_detour_count.data_handle(), - graph_size * input_graph_degree, - resource::get_cuda_stream(res)); + host_matrix_view_from_device detour_count(res, d_detour_count.view()); raft::copy( host_stats.data_handle(), dev_stats.data_handle(), 2, resource::get_cuda_stream(res)); @@ -447,7 +429,7 @@ void optimize(raft::resources const& res, if (max_detour < num_detour) { max_detour = num_detour; /* stats */ } for (uint64_t k = 0; k < input_graph_degree; k++) { if (detour_count.data_handle()[k + (input_graph_degree * i)] != num_detour) { continue; } - pruned_graph.data_handle()[pk + (output_graph_degree * i)] = + output_graph_ptr[pk + (output_graph_degree * i)] = input_graph_ptr[k + (input_graph_degree * i)]; pk += 1; if (pk >= output_graph_degree) break; @@ -478,8 +460,7 @@ void optimize(raft::resources const& res, // const double time_make_start = cur_time(); - auto d_rev_graph = - raft::make_device_matrix(res, graph_size, output_graph_degree); + device_matrix_view_from_host d_rev_graph(res, rev_graph.view()); RAFT_CUDA_TRY(cudaMemsetAsync(d_rev_graph.data_handle(), 0xff, graph_size * output_graph_degree * sizeof(IdxT), @@ -497,7 +478,7 @@ void optimize(raft::resources const& res, for (uint64_t k = 0; k < output_graph_degree; k++) { #pragma omp parallel for for (uint64_t i = 0; i < graph_size; i++) { - dest_nodes.data_handle()[i] = pruned_graph.data_handle()[k + (output_graph_degree * i)]; + dest_nodes.data_handle()[i] = output_graph_ptr[k + (output_graph_degree * i)]; } resource::sync_stream(res); @@ -520,10 +501,12 @@ void optimize(raft::resources const& res, resource::sync_stream(res); RAFT_LOG_DEBUG("\n"); - raft::copy(rev_graph.data_handle(), - d_rev_graph.data_handle(), - graph_size * output_graph_degree, - resource::get_cuda_stream(res)); + if (d_rev_graph.allocated_memory()) { + raft::copy(rev_graph.data_handle(), + d_rev_graph.data_handle(), + graph_size * output_graph_degree, + resource::get_cuda_stream(res)); + } raft::copy(rev_graph_count.data_handle(), d_rev_graph_count.data_handle(), graph_size, @@ -542,10 +525,6 @@ void optimize(raft::resources const& res, const uint64_t num_protected_edges = output_graph_degree / 2; RAFT_LOG_DEBUG("# num_protected_edges: %lu", num_protected_edges); - memcpy(output_graph_ptr, - pruned_graph.data_handle(), - sizeof(IdxT) * graph_size * output_graph_degree); - constexpr int _omp_chunk = 1024; #pragma omp parallel for schedule(dynamic, _omp_chunk) for (uint64_t j = 0; j < graph_size; j++) { @@ -578,7 +557,7 @@ void optimize(raft::resources const& res, #pragma omp parallel for reduction(+ : 
     for (uint64_t i = 0; i < graph_size; i++) {
       for (uint64_t k = 0; k < output_graph_degree; k++) {
-        const uint64_t j = pruned_graph.data_handle()[k + (output_graph_degree * i)];
+        const uint64_t j = output_graph_ptr[k + (output_graph_degree * i)];
         const uint64_t pos =
           pos_in_array(j, output_graph_ptr + (output_graph_degree * i), output_graph_degree);
         if (pos == output_graph_degree) { num_replaced_edges += 1; }
diff --git a/cpp/include/raft/neighbors/detail/cagra/utils.hpp b/cpp/include/raft/neighbors/detail/cagra/utils.hpp
index 22c7a60647..22cbe6bbac 100644
--- a/cpp/include/raft/neighbors/detail/cagra/utils.hpp
+++ b/cpp/include/raft/neighbors/detail/cagra/utils.hpp
@@ -20,6 +20,8 @@
 #include <cuda.h>
 #include <cuda_fp16.h>
 #include <raft/core/detail/macros.hpp>
+#include <raft/core/device_mdarray.hpp>
+#include <raft/core/host_mdarray.hpp>
 #include <type_traits>
 
 namespace raft::neighbors::cagra::detail {
@@ -150,4 +152,97 @@ struct gen_index_msb_1_mask {
 };
 }  // namespace utils
 
+/**
+ * Utility to sync memory from a host_matrix_view to a device_matrix_view
+ *
+ * In certain situations (UVM/HMM/ATS) host memory might be directly accessible on the
+ * device, and no extra allocations need to be performed. This class checks
+ * whether the host_matrix_view is already accessible on the device, and only allocates
+ * device memory and copies over if necessary. In memory-limited situations this is
+ * preferable to keeping both a host and a device copy.
+ * TODO: once the mdbuffer changes in https://github.com/wphicks/raft/blob/fea-mdbuffer
+ * have been merged, we should remove this class and switch over to using mdbuffer for this
+ */
+template <typename T, typename IdxT>
+class device_matrix_view_from_host {
+ public:
+  device_matrix_view_from_host(raft::resources const& res, host_matrix_view<T, IdxT> host_view)
+    : host_view_(host_view)
+  {
+    cudaPointerAttributes attr;
+    RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle()));
+    device_ptr = reinterpret_cast<T*>(attr.devicePointer);
+    if (device_ptr == NULL) {
+      // allocate memory and copy over
+      device_mem_.emplace(
+        raft::make_device_matrix<T, IdxT>(res, host_view.extent(0), host_view.extent(1)));
+      raft::copy(device_mem_->data_handle(),
+                 host_view.data_handle(),
+                 host_view.extent(0) * host_view.extent(1),
+                 resource::get_cuda_stream(res));
+      device_ptr = device_mem_->data_handle();
+    }
+  }
+
+  device_matrix_view<T, IdxT> view()
+  {
+    return make_device_matrix_view<T, IdxT>(device_ptr, host_view_.extent(0), host_view_.extent(1));
+  }
+
+  T* data_handle() { return device_ptr; }
+
+  bool allocated_memory() const { return device_mem_.has_value(); }
+
+ private:
+  std::optional<device_matrix<T, IdxT>> device_mem_;
+  host_matrix_view<T, IdxT> host_view_;
+  T* device_ptr;
+};
+
+/**
+ * Utility to sync memory from a device_matrix_view to a host_matrix_view
+ *
+ * In certain situations (UVM/HMM/ATS) device memory might be directly accessible on the
+ * host, and no extra allocations need to be performed. This class checks
+ * whether the device_matrix_view is already accessible on the host, and only allocates
+ * host memory and copies over if necessary.
+ * In memory-limited situations this is
+ * preferable to keeping both a host and a device copy.
+ * TODO: once the mdbuffer changes in https://github.com/wphicks/raft/blob/fea-mdbuffer
+ * have been merged, we should remove this class and switch over to using mdbuffer for this
+ */
+template <typename T, typename IdxT>
+class host_matrix_view_from_device {
+ public:
+  host_matrix_view_from_device(raft::resources const& res, device_matrix_view<T, IdxT> device_view)
+    : device_view_(device_view)
+  {
+    cudaPointerAttributes attr;
+    RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, device_view.data_handle()));
+    host_ptr = reinterpret_cast<T*>(attr.hostPointer);
+    if (host_ptr == NULL) {
+      // allocate memory and copy over
+      host_mem_.emplace(
+        raft::make_host_matrix<T, IdxT>(device_view.extent(0), device_view.extent(1)));
+      raft::copy(host_mem_->data_handle(),
+                 device_view.data_handle(),
+                 device_view.extent(0) * device_view.extent(1),
+                 resource::get_cuda_stream(res));
+      host_ptr = host_mem_->data_handle();
+    }
+  }
+
+  host_matrix_view<T, IdxT> view()
+  {
+    return make_host_matrix_view<T, IdxT>(host_ptr, device_view_.extent(0), device_view_.extent(1));
+  }
+
+  T* data_handle() { return host_ptr; }
+
+  bool allocated_memory() const { return host_mem_.has_value(); }
+
+ private:
+  std::optional<host_matrix<T, IdxT>> host_mem_;
+  device_matrix_view<T, IdxT> device_view_;
+  T* host_ptr;
+};
 }  // namespace raft::neighbors::cagra::detail
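
The adapters above are what let optimize() skip allocations and copies when host and device already share an address space (UVM/HMM/ATS). A minimal usage sketch of device_matrix_view_from_host follows; the res handle, the extents, and the main() scaffolding are illustrative only and not part of this diff:

// Sketch, assuming the utils.hpp from this diff is on the include path.
#include <cstdint>

#include <raft/core/device_resources.hpp>
#include <raft/core/host_mdarray.hpp>
#include <raft/neighbors/detail/cagra/utils.hpp>

int main()
{
  raft::device_resources res;

  // A plain pageable host allocation: cudaPointerGetAttributes() reports no
  // valid devicePointer for it, so the adapter allocates device memory and
  // copies. Under HMM/ATS (or for managed memory) the same pointer is already
  // device-accessible and the adapter performs no allocation and no copy.
  auto graph = raft::make_host_matrix<uint32_t, int64_t>(1000, 64);

  raft::neighbors::cagra::detail::device_matrix_view_from_host<uint32_t, int64_t> d_graph(
    res, graph.view());

  // d_graph.view() / d_graph.data_handle() can be handed to kernels either way.
  // allocated_memory() reports whether a separate device copy was made; this is
  // how optimize() decides whether d_rev_graph needs copying back to the host.
  if (d_graph.allocated_memory()) {
    // device-side writes would need an explicit copy back into `graph`
  }
  return 0;
}

host_matrix_view_from_device applies the same pattern in the other direction, which is why detour_count above only becomes a real host allocation when the device buffer is not already host-accessible.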