From 0e0c5f3b64803ab145b390277186989baffccaf9 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Wed, 30 Aug 2023 16:12:52 -0700 Subject: [PATCH] Revert serialization change: it was still making a copy from device->host inside serialize_mdspan, and with the include_dataset changes this branch won't even be called --- .../detail/cagra/cagra_serialize.cuh | 35 +++++++------------ 1 file changed, 12 insertions(+), 23 deletions(-) diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index d0e1f8ea4f..2c9cbd2563 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -71,29 +71,18 @@ void serialize(raft::resources const& res, serialize_scalar(res, os, include_dataset); if (include_dataset) { auto dataset = index_.dataset(); - if (dataset.stride(0) == dataset.extent(0)) { - // Rather than take another copy of the dataset here, just write it out directly. 
- // Since the dataset is a strided layout, we can't pass directly to the serialize_mdspan - // - but since the stride is the same as the extent, we can convert to a row-major - // mdspan - serialize_mdspan( - res, - os, - make_device_matrix_view(dataset.data_handle(), dataset.extent(0), dataset.extent(1))); - } else { - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, host_dataset.view()); - } + // Remove padding before saving the dataset + auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), + sizeof(T) * host_dataset.extent(1), + dataset.data_handle(), + sizeof(T) * dataset.stride(0), + sizeof(T) * host_dataset.extent(1), + dataset.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + serialize_mdspan(res, os, host_dataset.view()); } }