diff --git a/cpp/include/raft/matrix/detail/gather.cuh b/cpp/include/raft/matrix/detail/gather.cuh
index 73072ec841..767b8721a9 100644
--- a/cpp/include/raft/matrix/detail/gather.cuh
+++ b/cpp/include/raft/matrix/detail/gather.cuh
@@ -17,9 +17,15 @@
 #pragma once
 
 #include <functional>
+#include <raft/common/nvtx.hpp>
+#include <raft/core/device_mdspan.hpp>
+#include <raft/core/host_mdarray.hpp>
+#include <raft/core/host_mdspan.hpp>
 #include <raft/core/operators.hpp>
+#include <raft/core/pinned_mdarray.hpp>
+#include <raft/core/pinned_mdspan.hpp>
+#include <raft/core/resource/cuda_stream.hpp>
 #include <raft/util/cuda_utils.cuh>
-
 namespace raft {
 namespace matrix {
 namespace detail {
@@ -335,6 +341,76 @@ void gather_if(const InputIteratorT in,
   gatherImpl(in, D, N, map, stencil, map_length, out, pred_op, transform_op, stream);
 }
 
+template <typename T, typename IdxT>
+void gather_buff(host_matrix_view<const T, IdxT> dataset,
+                 host_vector_view<const IdxT, IdxT> indices,
+                 IdxT offset,
+                 pinned_matrix_view<T, IdxT> buff)
+{
+  raft::common::nvtx::range<raft::common::nvtx::domain::raft> fun_scope("Gather vectors");
+
+  IdxT batch_size = std::min(buff.extent(0), indices.extent(0) - offset);
+
+#pragma omp for
+  for (IdxT i = 0; i < batch_size; i++) {
+    IdxT in_idx = indices(offset + i);
+    for (IdxT k = 0; k < buff.extent(1); k++) {
+      buff(i, k) = dataset(in_idx, k);
+    }
+  }
+}
+
+template <typename T, typename IdxT>
+void gather(raft::resources const& res,
+            host_matrix_view<const T, IdxT> dataset,
+            device_vector_view<const IdxT, IdxT> indices,
+            raft::device_matrix_view<T, IdxT> output)
+{
+  IdxT n_dim   = output.extent(1);
+  IdxT n_train = output.extent(0);
+  auto indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
+  raft::copy(
+    indices_host.data_handle(), indices.data_handle(), n_train, resource::get_cuda_stream(res));
+  resource::sync_stream(res);
+
+  const size_t max_batch_size = 32768;
+  // Gather the vectors on the host into tmp buffers. We use two buffers to overlap the H2D
+  // copy with gathering the data.
+  raft::common::nvtx::push_range("subsample::alloc_buffers");
+  // rmm::mr::pinned_memory_resource mr_pinned;
+  // auto out_tmp1 = make_host_mdarray<T>(res, mr_pinned, make_extents<IdxT>(max_batch_size,
+  // n_dim)); auto out_tmp2 = make_host_mdarray<T>(res, mr_pinned,
+  // make_extents<IdxT>(max_batch_size, n_dim));
+  auto out_tmp1 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
+  auto out_tmp2 = raft::make_pinned_matrix<T, IdxT>(res, max_batch_size, n_dim);
+  auto view1    = out_tmp1.view();
+  auto view2    = out_tmp2.view();
+  raft::common::nvtx::pop_range();
+
+  gather_buff(dataset, make_const_mdspan(indices_host.view()), (IdxT)0, view1);
+#pragma omp parallel
+  for (IdxT device_offset = 0; device_offset < n_train; device_offset += max_batch_size) {
+    IdxT batch_size = std::min<IdxT>(max_batch_size, n_train - device_offset);
+#pragma omp master
+    raft::copy(output.data_handle() + device_offset * n_dim,
+               view1.data_handle(),
+               batch_size * n_dim,
+               resource::get_cuda_stream(res));
+    // Start gathering the next batch on the host.
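+    // (While the master thread's async H2D copy of view1 is in flight on the stream,
+    // all threads of the parallel region cooperatively fill view2: gather_buff uses omp for.)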
+    IdxT host_offset = device_offset + batch_size;
+    batch_size       = std::min<IdxT>(max_batch_size, n_train - host_offset);
+    if (batch_size > 0) {
+      gather_buff(dataset, make_const_mdspan(indices_host.view()), host_offset, view2);
+    }
+#pragma omp master
+    resource::sync_stream(res);
+#pragma omp barrier
+    std::swap(view1, view2);
+  }
+}
+
 }  // namespace detail
 }  // namespace matrix
 }  // namespace raft
diff --git a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
index bbe4c081e2..bd25506d44 100644
--- a/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
+++ b/cpp/include/raft/spatial/knn/detail/ann_utils.cuh
@@ -16,13 +16,15 @@
 #pragma once
 
+#include <raft/core/device_mdarray.hpp>
 #include <raft/core/logger.hpp>
 #include <raft/core/operators.hpp>
 #include <raft/core/resource/cuda_stream.hpp>
 #include <raft/core/resources.hpp>
+
 #include <raft/distance/distance_types.hpp>
 #include <raft/linalg/map.cuh>
 
-#include <raft/matrix/copy.cuh>
+#include <raft/matrix/gather.cuh>
 #include <raft/random/rng.cuh>
 #include <raft/util/cuda_utils.cuh>
 #include <raft/util/pow2_utils.cuh>
@@ -601,10 +603,6 @@ auto get_subsample_indices(raft::resources const& res, IdxT n_samples, IdxT n_su
                                            std::nullopt,
                                            train_indices.view(),
                                            std::nullopt);
-
-  thrust::sort(resource::get_thrust_policy(res),
-               train_indices.data_handle(),
-               train_indices.data_handle() + n_subsamples);
   return train_indices;
 }
 
@@ -618,12 +616,7 @@ void subsample(raft::resources const& res,
 {
   IdxT n_dim   = output.extent(1);
   IdxT n_train = output.extent(0);
-  if (n_train == n_samples) {
-    RAFT_LOG_INFO("No subsampling");
-    raft::copy(output.data_handle(), input, n_dim * n_samples, resource::get_cuda_stream(res));
-    return;
-  }
-  RAFT_LOG_DEBUG("Random subsampling");
+
   raft::device_vector<IdxT, IdxT> train_indices =
     get_subsample_indices(res, n_samples, n_train, seed);
   cudaPointerAttributes attr;
@@ -631,29 +624,13 @@ void subsample(raft::resources const& res,
   RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, input));
   T* ptr = reinterpret_cast<T*>(attr.devicePointer);
   if (ptr != nullptr) {
-    raft::matrix::copy_rows(res,
-                            raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
-                            output,
-                            raft::make_const_mdspan(train_indices.view()));
+    raft::matrix::gather(res,
+                         raft::make_device_matrix_view<const T, IdxT>(ptr, n_samples, n_dim),
+                         raft::make_const_mdspan(train_indices.view()),
+                         output);
   } else {
-    auto dataset            = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
-    auto train_indices_host = raft::make_host_vector<IdxT, IdxT>(n_train);
-    raft::copy(train_indices_host.data_handle(),
-               train_indices.data_handle(),
-               n_train,
-               resource::get_cuda_stream(res));
-    resource::sync_stream(res);
-    auto out_tmp = raft::make_host_matrix<T, IdxT>(n_train, n_dim);
-#pragma omp parallel for
-    for (IdxT i = 0; i < n_train; i++) {
-      IdxT in_idx = train_indices_host(i);
-      for (IdxT k = 0; k < n_dim; k++) {
-        out_tmp(i, k) = dataset(in_idx, k);
-      }
-    }
-    raft::copy(
-      output.data_handle(), out_tmp.data_handle(), output.size(), resource::get_cuda_stream(res));
-    resource::sync_stream(res);
+    auto dataset = raft::make_host_matrix_view<const T, IdxT>(input, n_samples, n_dim);
+    raft::matrix::detail::gather(res, dataset, make_const_mdspan(train_indices.view()), output);
   }
 }
 }  // namespace raft::spatial::knn::detail::utils