Skip to content

Commit

Permalink
Fix deserialization: set the padding bytes to zero in the strided dataset.
Browse files Browse the repository at this point in the history
  • Loading branch information
achirkin committed Mar 11, 2024
1 parent 292406c commit 24ebae2
Showing 1 changed file with 14 additions and 9 deletions.
23 changes: 14 additions & 9 deletions cpp/include/raft/neighbors/detail/dataset_serialize.hpp
Original file line number Diff line number Diff line change
Expand Up @@ -44,18 +44,21 @@ void serialize(const raft::resources& res,
std::ostream& os,
const strided_dataset<DataT, IdxT>& dataset)
{
serialize_scalar(res, os, dataset.n_rows());
serialize_scalar(res, os, dataset.dim());
serialize_scalar(res, os, dataset.stride());
auto n_rows = dataset.n_rows();
auto dim = dataset.dim();
auto stride = dataset.stride();
serialize_scalar(res, os, n_rows);
serialize_scalar(res, os, dim);
serialize_scalar(res, os, stride);
// Remove padding before saving the dataset
auto src = dataset.view();
auto dst = make_host_mdarray<DataT, IdxT>(src.extents());
auto dst = make_host_matrix<DataT, IdxT>(n_rows, dim);
RAFT_CUDA_TRY(cudaMemcpy2DAsync(dst.data_handle(),
sizeof(DataT) * dst.extent(1),
sizeof(DataT) * dim,
src.data_handle(),
sizeof(DataT) * src.stride(0),
sizeof(DataT) * dst.extent(1),
src.extent(0),
sizeof(DataT) * stride,
sizeof(DataT) * dim,
n_rows,
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
Expand Down Expand Up @@ -144,8 +147,10 @@ void deserialize(raft::resources const& res,
using out_container_policy_type = typename out_mdarray_type::container_policy_type;
using out_owning_type = owning_dataset<DataT, IdxT, out_layout_type, out_container_policy_type>;

auto host_arrray = make_host_mdarray<DataT, IdxT>(out_extents);
auto host_arrray = make_host_matrix<DataT, IdxT>(n_rows, dim);
deserialize_mdspan(res, is, host_arrray.view());
RAFT_CUDA_TRY(cudaMemsetAsync(
out_array.data_handle(), 0, sizeof(DataT) * out_array.size(), resource::get_cuda_stream(res)));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(out_array.data_handle(),
sizeof(DataT) * stride,
host_arrray.data_handle(),
Expand Down

0 comments on commit 24ebae2

Please sign in to comment.