From cb11327ab4b9f7d2735100f503af03e32791ed56 Mon Sep 17 00:00:00 2001 From: achirkin Date: Tue, 12 Mar 2024 11:33:11 +0100 Subject: [PATCH] Further simplify deserialization code --- .../detail/cagra/cagra_serialize.cuh | 4 +- .../neighbors/detail/dataset_serialize.hpp | 103 ++++-------------- 2 files changed, 24 insertions(+), 83 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index f5556b256a..cb8073d8f6 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -249,9 +249,7 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index idx.update_graph(res, raft::make_const_mdspan(graph.view())); bool has_dataset = deserialize_scalar(res, is); if (has_dataset) { - std::unique_ptr> dataset; - neighbors::detail::deserialize(res, is, dataset); - idx.update_dataset(res, std::move(dataset)); + idx.update_dataset(res, neighbors::detail::deserialize_dataset(res, is)); } return idx; } diff --git a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp index dc60a4782d..dc55d891be 100644 --- a/cpp/include/raft/neighbors/detail/dataset_serialize.hpp +++ b/cpp/include/raft/neighbors/detail/dataset_serialize.hpp @@ -122,51 +122,28 @@ void serialize(const raft::resources& res, std::ostream& os, const dataset } template -void deserialize(raft::resources const& res, - std::istream& is, - std::unique_ptr>& out) +auto deserialize_empty(raft::resources const& res, std::istream& is) + -> std::unique_ptr> { auto suggested_dim = deserialize_scalar(res, is); - out = std::make_unique>(suggested_dim); + return std::make_unique>(suggested_dim); } template -void deserialize(raft::resources const& res, - std::istream& is, - std::unique_ptr>& out) +auto deserialize_strided(raft::resources const& res, std::istream& is) + -> std::unique_ptr> { - auto n_rows = deserialize_scalar(res, is); - auto dim = deserialize_scalar(res, is); - auto stride = deserialize_scalar(res, is); - auto out_extents = make_extents(n_rows, dim); - auto out_layout = make_strided_layout(out_extents, std::array{stride, 1}); - auto out_array = make_device_matrix(res, n_rows, stride); - - using out_mdarray_type = decltype(out_array); - using out_layout_type = typename out_mdarray_type::layout_type; - using out_container_policy_type = typename out_mdarray_type::container_policy_type; - using out_owning_type = owning_dataset; - - auto host_arrray = make_host_matrix(n_rows, dim); - deserialize_mdspan(res, is, host_arrray.view()); - RAFT_CUDA_TRY(cudaMemsetAsync( - out_array.data_handle(), 0, sizeof(DataT) * out_array.size(), resource::get_cuda_stream(res))); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(), - sizeof(DataT) * stride, - host_arrray.data_handle(), - sizeof(DataT) * dim, - sizeof(DataT) * dim, - n_rows, - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - - out = std::make_unique(std::move(out_array), out_layout); + auto n_rows = deserialize_scalar(res, is); + auto dim = deserialize_scalar(res, is); + auto stride = deserialize_scalar(res, is); + auto host_array = make_host_matrix(n_rows, dim); + deserialize_mdspan(res, is, host_array.view()); + return construct_strided_dataset(res, host_array, stride); } template -void deserialize(raft::resources const& res, - std::istream& is, - std::unique_ptr>& out) +auto deserialize_vpq(raft::resources const& res, std::istream& is) + -> std::unique_ptr> { auto n_rows = deserialize_scalar(res, is); auto dim = deserialize_scalar(res, is); @@ -183,62 +160,28 @@ void deserialize(raft::resources const& res, deserialize_mdspan(res, is, pq_code_book.view()); deserialize_mdspan(res, is, data.view()); - out = std::make_unique>( + return std::make_unique>( std::move(vq_code_book), std::move(pq_code_book), std::move(data)); } template -void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr>& out) +auto deserialize_dataset(raft::resources const& res, std::istream& is) + -> std::unique_ptr> { switch (deserialize_scalar(res, is)) { - case kSerializeEmptyDataset: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } + case kSerializeEmptyDataset: return deserialize_empty(res, is); case kSerializeStridedDataset: switch (deserialize_scalar(res, is)) { - case CUDA_R_32F: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } - case CUDA_R_16F: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } - case CUDA_R_8I: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } - case CUDA_R_8U: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } + case CUDA_R_32F: return deserialize_strided(res, is); + case CUDA_R_16F: return deserialize_strided(res, is); + case CUDA_R_8I: return deserialize_strided(res, is); + case CUDA_R_8U: return deserialize_strided(res, is); default: break; } case kSerializeVPQDataset: switch (deserialize_scalar(res, is)) { - case CUDA_R_32F: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } - case CUDA_R_16F: { - std::unique_ptr> p; - deserialize(res, is, p); - out = std::move(p); - return; - } + case CUDA_R_32F: return deserialize_vpq(res, is); + case CUDA_R_16F: return deserialize_vpq(res, is); default: break; } default: break;