Skip to content

Commit

Permalink
revert serialization change
Browse files Browse the repository at this point in the history
was still making a copy from device->host inside serialize_mdspan, and with the
include_dataset changes this branch won't even be called
  • Loading branch information
benfred committed Aug 30, 2023
1 parent f539655 commit 0e0c5f3
Showing 1 changed file with 12 additions and 23 deletions.
35 changes: 12 additions & 23 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -71,29 +71,18 @@ void serialize(raft::resources const& res,
serialize_scalar(res, os, include_dataset);
if (include_dataset) {
auto dataset = index_.dataset();
if (dataset.stride(0) == dataset.extent(0)) {
// Rather than take another copy of the dataset here, just write it out directly.
// Since the dataset is a strided layout, we can't pass directly to the serialize_mdspan
// - but since the stride is the same as the extent, we can convert to a row-major
// mdspan
serialize_mdspan(
res,
os,
make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1)));
} else {
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
}
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
}
}

Expand Down

0 comments on commit 0e0c5f3

Please sign in to comment.