Skip to content

Commit

Permalink
Further simplify deserialization code
Browse files (browse the repository at this point in the history)
  • Loading branch information
achirkin committed Mar 12, 2024
1 parent 24ebae2 commit cb11327
Show file tree
Hide file tree
Showing 2 changed files with 24 additions and 83 deletions.
4 changes: 1 addition & 3 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -249,9 +249,7 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
idx.update_graph(res, raft::make_const_mdspan(graph.view()));
bool has_dataset = deserialize_scalar<bool>(res, is);
if (has_dataset) {
std::unique_ptr<dataset<int64_t>> dataset;
neighbors::detail::deserialize(res, is, dataset);
idx.update_dataset(res, std::move(dataset));
idx.update_dataset(res, neighbors::detail::deserialize_dataset<int64_t>(res, is));
}
return idx;
}
Expand Down
103 changes: 23 additions & 80 deletions cpp/include/raft/neighbors/detail/dataset_serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -122,51 +122,28 @@ void serialize(const raft::resources& res, std::ostream& os, const dataset<IdxT>
}

template <typename IdxT>
void deserialize(raft::resources const& res,
std::istream& is,
std::unique_ptr<empty_dataset<IdxT>>& out)
auto deserialize_empty(raft::resources const& res, std::istream& is)
-> std::unique_ptr<empty_dataset<IdxT>>
{
auto suggested_dim = deserialize_scalar<uint32_t>(res, is);
out = std::make_unique<empty_dataset<IdxT>>(suggested_dim);
return std::make_unique<empty_dataset<IdxT>>(suggested_dim);
}

template <typename DataT, typename IdxT>
void deserialize(raft::resources const& res,
std::istream& is,
std::unique_ptr<strided_dataset<DataT, IdxT>>& out)
auto deserialize_strided(raft::resources const& res, std::istream& is)
-> std::unique_ptr<strided_dataset<DataT, IdxT>>
{
auto n_rows = deserialize_scalar<IdxT>(res, is);
auto dim = deserialize_scalar<uint32_t>(res, is);
auto stride = deserialize_scalar<uint32_t>(res, is);
auto out_extents = make_extents<IdxT>(n_rows, dim);
auto out_layout = make_strided_layout(out_extents, std::array<IdxT, 2>{stride, 1});
auto out_array = make_device_matrix<DataT, IdxT>(res, n_rows, stride);

using out_mdarray_type = decltype(out_array);
using out_layout_type = typename out_mdarray_type::layout_type;
using out_container_policy_type = typename out_mdarray_type::container_policy_type;
using out_owning_type = owning_dataset<DataT, IdxT, out_layout_type, out_container_policy_type>;

auto host_arrray = make_host_matrix<DataT, IdxT>(n_rows, dim);
deserialize_mdspan(res, is, host_arrray.view());
RAFT_CUDA_TRY(cudaMemsetAsync(
out_array.data_handle(), 0, sizeof(DataT) * out_array.size(), resource::get_cuda_stream(res)));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(),
sizeof(DataT) * stride,
host_arrray.data_handle(),
sizeof(DataT) * dim,
sizeof(DataT) * dim,
n_rows,
cudaMemcpyDefault,
resource::get_cuda_stream(res)));

out = std::make_unique<out_owning_type>(std::move(out_array), out_layout);
auto n_rows = deserialize_scalar<IdxT>(res, is);
auto dim = deserialize_scalar<uint32_t>(res, is);
auto stride = deserialize_scalar<uint32_t>(res, is);
auto host_array = make_host_matrix<DataT, IdxT>(n_rows, dim);
deserialize_mdspan(res, is, host_array.view());
return construct_strided_dataset(res, host_array, stride);
}

template <typename MathT, typename IdxT>
void deserialize(raft::resources const& res,
std::istream& is,
std::unique_ptr<vpq_dataset<MathT, IdxT>>& out)
auto deserialize_vpq(raft::resources const& res, std::istream& is)
-> std::unique_ptr<vpq_dataset<MathT, IdxT>>
{
auto n_rows = deserialize_scalar<IdxT>(res, is);
auto dim = deserialize_scalar<uint32_t>(res, is);
Expand All @@ -183,62 +160,28 @@ void deserialize(raft::resources const& res,
deserialize_mdspan(res, is, pq_code_book.view());
deserialize_mdspan(res, is, data.view());

out = std::make_unique<vpq_dataset<MathT, IdxT>>(
return std::make_unique<vpq_dataset<MathT, IdxT>>(
std::move(vq_code_book), std::move(pq_code_book), std::move(data));
}

template <typename IdxT>
void deserialize(raft::resources const& res, std::istream& is, std::unique_ptr<dataset<IdxT>>& out)
auto deserialize_dataset(raft::resources const& res, std::istream& is)
-> std::unique_ptr<dataset<IdxT>>
{
switch (deserialize_scalar<dataset_instance_tag>(res, is)) {
case kSerializeEmptyDataset: {
std::unique_ptr<empty_dataset<IdxT>> p;
deserialize(res, is, p);
out = std::move(p);
return;
}
case kSerializeEmptyDataset: return deserialize_empty<IdxT>(res, is);
case kSerializeStridedDataset:
switch (deserialize_scalar<cudaDataType_t>(res, is)) {
case CUDA_R_32F: {
std::unique_ptr<strided_dataset<float, IdxT>> p;
deserialize(res, is, p);
out = std::move(p);
return;
}
case CUDA_R_16F: {
std::unique_ptr<strided_dataset<half, IdxT>> p;
deserialize(res, is, p);
out = std::move(p);
return;
}
case CUDA_R_8I: {
std::unique_ptr<strided_dataset<int8_t, IdxT>> p;
deserialize(res, is, p);
out = std::move(p);
return;
}
case CUDA_R_8U: {
std::unique_ptr<strided_dataset<uint8_t, IdxT>> p;
deserialize(res, is, p);
out = std::move(p);
return;
}
case CUDA_R_32F: return deserialize_strided<float, IdxT>(res, is);
case CUDA_R_16F: return deserialize_strided<half, IdxT>(res, is);
case CUDA_R_8I: return deserialize_strided<int8_t, IdxT>(res, is);
case CUDA_R_8U: return deserialize_strided<uint8_t, IdxT>(res, is);
default: break;
}
case kSerializeVPQDataset:
switch (deserialize_scalar<cudaDataType_t>(res, is)) {
case CUDA_R_32F: {
std::unique_ptr<vpq_dataset<float, IdxT>> p;
deserialize(res, is, p);
out = std::move(p);
return;
}
case CUDA_R_16F: {
std::unique_ptr<vpq_dataset<half, IdxT>> p;
deserialize(res, is, p);
out = std::move(p);
return;
}
case CUDA_R_32F: return deserialize_vpq<float, IdxT>(res, is);
case CUDA_R_16F: return deserialize_vpq<half, IdxT>(res, is);
default: break;
}
default: break;
Expand Down

0 comments on commit cb11327

Please sign in to comment.