diff --git a/cpp/include/cuvs/neighbors/common.hpp b/cpp/include/cuvs/neighbors/common.hpp index 60b8cc122..bd9ea4834 100644 --- a/cpp/include/cuvs/neighbors/common.hpp +++ b/cpp/include/cuvs/neighbors/common.hpp @@ -264,6 +264,77 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t return std::make_unique(std::move(out_array), out_layout); } +/** + * @brief Contstruct a strided matrix from any mdarray. + * + * This function constructs an owning device matrix and copies the data. + * When the data is copied, padding elements are filled with zeroes. + * + * @tparam DataT + * @tparam IdxT + * @tparam LayoutPolicy + * @tparam ContainerPolicy + * + * @param[in] res raft resources handle + * @param[in] src the source mdarray or mdspan + * @param[in] required_stride the leading dimension (in elements) + * @return owning current-device-accessible strided matrix + */ +template +auto make_strided_dataset( + const raft::resources& res, + raft::mdarray, LayoutPolicy, ContainerPolicy>&& src, + uint32_t required_stride) -> std::unique_ptr> +{ + using value_type = DataT; + using index_type = IdxT; + using layout_type = LayoutPolicy; + using container_policy_type = ContainerPolicy; + static_assert(std::is_same_v || + std::is_same_v> || + std::is_same_v, + "The input must be row-major"); + RAFT_EXPECTS(src.extent(1) <= required_stride, + "The input row length must be not larger than the desired stride."); + const uint32_t src_stride = src.stride(0) > 0 ? src.stride(0) : src.extent(1); + const bool stride_matches = required_stride == src_stride; + + auto out_layout = + raft::make_strided_layout(src.extents(), std::array{required_stride, 1}); + + using out_mdarray_type = raft::device_matrix; + using out_layout_type = typename out_mdarray_type::layout_type; + using out_container_policy_type = typename out_mdarray_type::container_policy_type; + using out_owning_type = + owning_dataset; + + if constexpr (std::is_same_v && + std::is_same_v) { + if (stride_matches) { + // Everything matches, we can own the mdarray + return std::make_unique(std::move(src), out_layout); + } + } + // Something is wrong: have to make a copy and produce an owning dataset + auto out_array = + raft::make_device_matrix(res, src.extent(0), required_stride); + + RAFT_CUDA_TRY(cudaMemsetAsync(out_array.data_handle(), + 0, + out_array.size() * sizeof(value_type), + raft::resource::get_cuda_stream(res))); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), + sizeof(value_type) * required_stride, + src.data_handle(), + sizeof(value_type) * src_stride, + sizeof(value_type) * src.extent(1), + src.extent(0), + cudaMemcpyDefault, + raft::resource::get_cuda_stream(res))); + + return std::make_unique(std::move(out_array), out_layout); +} + /** * @brief Contstruct a strided matrix from any mdarray or mdspan. * @@ -278,14 +349,15 @@ auto make_strided_dataset(const raft::resources& res, const SrcT& src, uint32_t * @return maybe owning current-device-accessible strided matrix */ template -auto make_aligned_dataset(const raft::resources& res, const SrcT& src, uint32_t align_bytes = 16) +auto make_aligned_dataset(const raft::resources& res, SrcT src, uint32_t align_bytes = 16) -> std::unique_ptr> { - using value_type = typename SrcT::value_type; + using source_type = std::remove_cv_t>; + using value_type = typename source_type::value_type; constexpr size_t kSize = sizeof(value_type); uint32_t required_stride = raft::round_up_safe(src.extent(1) * kSize, std::lcm(align_bytes, kSize)) / kSize; - return make_strided_dataset(res, src, required_stride); + return make_strided_dataset(res, std::forward(src), required_stride); } /** * @brief VPQ compressed dataset. diff --git a/cpp/src/neighbors/detail/ann_utils.cuh b/cpp/src/neighbors/detail/ann_utils.cuh index 652d41c85..529356351 100644 --- a/cpp/src/neighbors/detail/ann_utils.cuh +++ b/cpp/src/neighbors/detail/ann_utils.cuh @@ -403,6 +403,17 @@ struct batch_load_iterator { /** A single batch of data residing in device memory. */ struct batch { + ~batch() noexcept + { + /* + If there's no copy, there's no allocation owned by the batch. + If there's no allocation, there's no guarantee that the device pointer is stream-ordered. + If there's no stream order guarantee, we must synchronize with the stream before the batch is + destroyed to make sure all GPU operations in that stream finish earlier. + */ + if (!does_copy()) { RAFT_CUDA_TRY_NO_THROW(cudaStreamSynchronize(stream_)); } + } + /** Logical width of a single row in a batch, in elements of type `T`. */ [[nodiscard]] auto row_width() const -> size_type { return row_width_; } /** Logical offset of the batch, in rows (`row_width()`) */ diff --git a/cpp/src/neighbors/detail/cagra/utils.hpp b/cpp/src/neighbors/detail/cagra/utils.hpp index 0f8309328..9f95c5b1c 100644 --- a/cpp/src/neighbors/detail/cagra/utils.hpp +++ b/cpp/src/neighbors/detail/cagra/utils.hpp @@ -179,7 +179,7 @@ class device_matrix_view_from_host { public: device_matrix_view_from_host(raft::resources const& res, raft::host_matrix_view host_view) - : host_view_(host_view) + : res_(res), host_view_(host_view) { cudaPointerAttributes attr; RAFT_CUDA_TRY(cudaPointerGetAttributes(&attr, host_view.data_handle())); @@ -199,6 +199,17 @@ class device_matrix_view_from_host { } } + ~device_matrix_view_from_host() noexcept + { + /* + If there's no copy, there's no allocation owned by this struct. + If there's no allocation, there's no guarantee that the device pointer is stream-ordered. + If there's no stream order guarantee, we must synchronize with the stream before the struct is + destroyed to make sure all GPU operations in that stream finish earlier. + */ + if (!allocated_memory()) { raft::resource::sync_stream(res_); } + } + raft::device_matrix_view view() { return raft::make_device_matrix_view( @@ -207,9 +218,10 @@ class device_matrix_view_from_host { T* data_handle() { return device_ptr; } - bool allocated_memory() const { return device_mem_.has_value(); } + [[nodiscard]] bool allocated_memory() const { return device_mem_.has_value(); } private: + const raft::resources& res_; std::optional> device_mem_; raft::host_matrix_view host_view_; T* device_ptr; diff --git a/cpp/src/neighbors/detail/dataset_serialize.hpp b/cpp/src/neighbors/detail/dataset_serialize.hpp index 40d9df930..0ecc2cf5d 100644 --- a/cpp/src/neighbors/detail/dataset_serialize.hpp +++ b/cpp/src/neighbors/detail/dataset_serialize.hpp @@ -140,7 +140,7 @@ auto deserialize_strided(raft::resources const& res, std::istream& is) auto stride = raft::deserialize_scalar(res, is); auto host_array = raft::make_host_matrix(n_rows, dim); raft::deserialize_mdspan(res, is, host_array.view()); - return make_strided_dataset(res, host_array, stride); + return make_strided_dataset(res, std::move(host_array), stride); } template diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 8d5701439..c1cd3ca09 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -389,12 +389,13 @@ class AnnCagraTest : public ::testing::TestWithParam { (const DataT*)database.data(), ps.n_rows, ps.dim); { + std::optional> database_host{std::nullopt}; cagra::index index(handle_, index_params.metric); if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host->data_handle(), database.data(), database.size(), stream_); auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + (const DataT*)database_host->data_handle(), ps.n_rows, ps.dim); index = cagra::build(handle_, index_params, database_host_view); } else { @@ -567,13 +568,16 @@ class AnnCagraAddNodesTest : public ::testing::TestWithParam { auto initial_database_view = raft::make_device_matrix_view( (const DataT*)database.data(), initial_database_size, ps.dim); + std::optional> database_host{std::nullopt}; cagra::index index(handle_); if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + database_host = raft::make_host_matrix(ps.n_rows, ps.dim); raft::copy( - database_host.data_handle(), database.data(), initial_database_view.size(), stream_); + database_host->data_handle(), database.data(), initial_database_view.size(), stream_); auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), initial_database_size, ps.dim); + (const DataT*)database_host->data_handle(), initial_database_size, ps.dim); + // NB: database_host must live no less than the index, because the index _may_be_ + // non-onwning index = cagra::build(handle_, index_params, database_host_view); } else { index = cagra::build(handle_, index_params, initial_database_view); @@ -763,12 +767,13 @@ class AnnCagraFilterTest : public ::testing::TestWithParam { auto database_view = raft::make_device_matrix_view( (const DataT*)database.data(), ps.n_rows, ps.dim); + std::optional> database_host{std::nullopt}; cagra::index index(handle_); if (ps.host_dataset) { - auto database_host = raft::make_host_matrix(ps.n_rows, ps.dim); - raft::copy(database_host.data_handle(), database.data(), database.size(), stream_); + database_host = raft::make_host_matrix(ps.n_rows, ps.dim); + raft::copy(database_host->data_handle(), database.data(), database.size(), stream_); auto database_host_view = raft::make_host_matrix_view( - (const DataT*)database_host.data_handle(), ps.n_rows, ps.dim); + (const DataT*)database_host->data_handle(), ps.n_rows, ps.dim); index = cagra::build(handle_, index_params, database_host_view); } else { index = cagra::build(handle_, index_params, database_view);