Update CAGRA serialization (#1755)
This changes the serialization format of saved CAGRA indices in two ways:

* The dtype is now written in the first 4 bytes of the serialized file, matching the IVF methods and making it easier to deduce the dtype from Python (#1729).
* Writing out the dataset with the index is now optional. Many use cases already have the dataset stored separately, so skipping the embedded copy saves disk space. If `include_dataset=false` is given, you must call `index.update_dataset` to set the dataset yourself after loading (see the sketch below).
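
A minimal usage sketch of the new `include_dataset` option (hypothetical caller code, not part of this change; it assumes a prebuilt `cagra::index<float, uint32_t>` and the row-major device view of the dataset it was built from):

```cpp
#include <raft/core/device_mdspan.hpp>
#include <raft/core/device_resources.hpp>
#include <raft/neighbors/cagra_serialize.cuh>
#include <raft/neighbors/cagra_types.hpp>

void save_and_reload(raft::device_resources const& res,
                     raft::neighbors::cagra::index<float, uint32_t> const& index,
                     raft::device_matrix_view<const float, int64_t> dataset)
{
  // Skip the embedded dataset copy to save disk space.
  raft::neighbors::cagra::serialize(res, "cagra_index", index, /*include_dataset=*/false);

  // An index saved this way has no dataset attached after loading;
  // re-attach it before searching.
  auto loaded = raft::neighbors::cagra::deserialize<float, uint32_t>(res, "cagra_index");
  loaded.update_dataset(res, dataset);
}
```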

Authors:
  - Ben Frederickson (https://github.com/benfred)

Approvers:
  - Corey J. Nolet (https://github.com/cjnolet)

URL: #1755
benfred authored Aug 21, 2023
1 parent 7aded12 commit ea9d395
Showing 10 changed files with 284 additions and 52 deletions.
16 changes: 11 additions & 5 deletions cpp/include/raft/neighbors/cagra_serialize.cuh
@@ -47,12 +47,16 @@ namespace raft::neighbors::cagra {
* @param[in] handle the raft handle
* @param[in] os output stream
* @param[in] index CAGRA index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& handle, std::ostream& os, const index<T, IdxT>& index)
void serialize(raft::resources const& handle,
std::ostream& os,
const index<T, IdxT>& index,
bool include_dataset = true)
{
detail::serialize(handle, os, index);
detail::serialize(handle, os, index, include_dataset);
}

/**
@@ -77,14 +81,16 @@ void serialize(raft::resources const& handle, std::ostream& os, const index<T, I
* @param[in] handle the raft handle
* @param[in] filename the file name for saving the index
* @param[in] index CAGRA index
* @param[in] include_dataset Whether or not to write out the dataset to the file.
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& handle,
const std::string& filename,
const index<T, IdxT>& index)
const index<T, IdxT>& index,
bool include_dataset = true)
{
detail::serialize(handle, filename, index);
detail::serialize(handle, filename, index, include_dataset);
}

/**
@@ -158,4 +164,4 @@ namespace raft::neighbors::experimental::cagra {
using raft::neighbors::cagra::deserialize;
using raft::neighbors::cagra::serialize;

} // namespace raft::neighbors::experimental::cagra
} // namespace raft::neighbors::experimental::cagra
54 changes: 35 additions & 19 deletions cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh
@@ -24,8 +24,7 @@

namespace raft::neighbors::cagra::detail {

// Serialization version 1.
constexpr int serialization_version = 2;
constexpr int serialization_version = 3;

// NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error
// message.
@@ -50,41 +49,53 @@ template struct check_index_layout<sizeof(index<double, std::uint64_t>), expecte
*
*/
template <typename T, typename IdxT>
void serialize(raft::resources const& res, std::ostream& os, const index<T, IdxT>& index_)
void serialize(raft::resources const& res,
std::ostream& os,
const index<T, IdxT>& index_,
bool include_dataset)
{
RAFT_LOG_DEBUG(
"Saving CAGRA index, size %zu, dim %u", static_cast<size_t>(index_.size()), index_.dim());

std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype<T>().to_string();
dtype_string.resize(4);
os << dtype_string;

serialize_scalar(res, os, serialization_version);
serialize_scalar(res, os, index_.size());
serialize_scalar(res, os, index_.dim());
serialize_scalar(res, os, index_.graph_degree());
serialize_scalar(res, os, index_.metric());
auto dataset = index_.dataset();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
serialize_mdspan(res, os, index_.graph());

serialize_scalar(res, os, include_dataset);
if (include_dataset) {
auto dataset = index_.dataset();
// Remove padding before saving the dataset
auto host_dataset = make_host_matrix<T, int64_t>(dataset.extent(0), dataset.extent(1));
RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(),
sizeof(T) * host_dataset.extent(1),
dataset.data_handle(),
sizeof(T) * dataset.stride(0),
sizeof(T) * host_dataset.extent(1),
dataset.extent(0),
cudaMemcpyDefault,
resource::get_cuda_stream(res)));
resource::sync_stream(res);
serialize_mdspan(res, os, host_dataset.view());
}
}

template <typename T, typename IdxT>
void serialize(raft::resources const& res,
const std::string& filename,
const index<T, IdxT>& index_)
const index<T, IdxT>& index_,
bool include_dataset)
{
std::ofstream of(filename, std::ios::out | std::ios::binary);
if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); }

detail::serialize(res, of, index_);
detail::serialize(res, of, index_, include_dataset);

of.close();
if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); }
@@ -102,6 +113,9 @@ void serialize(raft::resources const& res,
template <typename T, typename IdxT>
auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>
{
char dtype_string[4];
is.read(dtype_string, 4);

auto ver = deserialize_scalar<int>(res, is);
if (ver != serialization_version) {
RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver);
@@ -113,9 +127,11 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index<T, IdxT>

auto dataset = raft::make_host_matrix<T, int64_t>(n_rows, dim);
auto graph = raft::make_host_matrix<IdxT, int64_t>(n_rows, graph_degree);
deserialize_mdspan(res, is, dataset.view());
deserialize_mdspan(res, is, graph.view());

bool has_dataset = deserialize_scalar<bool>(res, is);
if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); }

return index<T, IdxT>(
res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view()));
}
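Because the dtype string and version now lead the file, the element type can be sniffed without deserializing the whole index. A rough sketch of peeking at that header (the NUL-padded numpy dtype string and the assumption that `serialize_scalar` writes the version as raw `int` bytes are read off this diff, not guaranteed by the API):

```cpp
#include <fstream>
#include <iostream>
#include <string>

// Print the numpy-style dtype string (e.g. "<f4", padded to 4 bytes) and the
// serialization version written at the start of a saved CAGRA index.
void peek_cagra_header(const std::string& filename)
{
  std::ifstream is(filename, std::ios::binary);
  char dtype[5] = {};  // 4 bytes on disk plus a terminator for printing
  is.read(dtype, 4);
  int version = 0;
  is.read(reinterpret_cast<char*>(&version), sizeof(version));
  std::cout << "dtype=" << dtype << " version=" << version << '\n';
}
```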
6 changes: 4 additions & 2 deletions cpp/include/raft_runtime/neighbors/cagra.hpp
@@ -56,14 +56,16 @@ namespace raft::runtime::neighbors::cagra {
raft::device_matrix_view<float, int64_t, row_major> distances); \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void deserialize_file(raft::resources const& handle, \
const std::string& filename, \
raft::neighbors::cagra::index<T, IdxT>* index); \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<T, IdxT>& index); \
const raft::neighbors::cagra::index<T, IdxT>& index, \
bool include_dataset = true); \
\
void deserialize(raft::resources const& handle, \
const std::string& str, \
10 changes: 6 additions & 4 deletions cpp/src/raft_runtime/neighbors/cagra_serialize.cu
@@ -27,9 +27,10 @@ namespace raft::runtime::neighbors::cagra {
#define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \
void serialize_file(raft::resources const& handle, \
const std::string& filename, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
raft::neighbors::cagra::serialize(handle, filename, index); \
raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \
}; \
\
void deserialize_file(raft::resources const& handle, \
@@ -41,10 +42,11 @@ namespace raft::runtime::neighbors::cagra {
}; \
void serialize(raft::resources const& handle, \
std::string& str, \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index) \
const raft::neighbors::cagra::index<DTYPE, uint32_t>& index, \
bool include_dataset) \
{ \
std::stringstream os; \
raft::neighbors::cagra::serialize(handle, os, index); \
raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \
str = os.str(); \
} \
\
13 changes: 10 additions & 3 deletions cpp/test/neighbors/ann_cagra.cuh
@@ -137,6 +137,7 @@ struct AnnCagraInputs {
int search_width;
raft::distance::DistanceType metric;
bool host_dataset;
bool include_serialized_dataset;
// std::optional<double>
double min_recall; // = std::nullopt;
};
@@ -217,9 +218,11 @@ class AnnCagraTest : public ::testing::TestWithParam<AnnCagraInputs> {
} else {
index = cagra::build<DataT, IdxT>(handle_, index_params, database_view);
};
cagra::serialize(handle_, "cagra_index", index);
cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset);
}

auto index = cagra::deserialize<DataT, IdxT>(handle_, "cagra_index");
if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); }

auto search_queries_view = raft::make_device_matrix_view<const DataT, int64_t>(
search_queries.data(), ps.n_queries, ps.dim);
@@ -340,9 +343,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam<AnnCagraInputs> {

void SetUp() override
{
std::cout << "Resizing database: " << ps.n_rows * ps.dim << std::endl;
database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream());
std::cout << "Done.\nRuning rng" << std::endl;
raft::random::Rng r(1234ULL);
if constexpr (std::is_same<DataT, float>{}) {
GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream());
@@ -379,6 +380,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{true},
{0.995});

auto inputs2 = raft::util::itertools::product<AnnCagraInputs>(
@@ -393,6 +395,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{true},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());
inputs2 =
@@ -407,6 +410,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{false},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

@@ -422,6 +426,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false},
{true},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

@@ -437,6 +442,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false, true},
{false},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

@@ -452,6 +458,7 @@ inline std::vector<AnnCagraInputs> generate_inputs()
{1},
{raft::distance::DistanceType::L2Expanded},
{false, true},
{true},
{0.995});
inputs.insert(inputs.end(), inputs2.begin(), inputs2.end());

24 changes: 24 additions & 0 deletions python/pylibraft/pylibraft/common/mdspan.pxd
@@ -30,6 +30,12 @@ from pylibraft.common.cpp.mdspan cimport (
from pylibraft.common.handle cimport device_resources
from pylibraft.common.optional cimport make_optional, optional

# Cython doesn't like `const float` inside template parameters
# hack around this with using typedefs
ctypedef const float const_float
ctypedef const int8_t const_int8_t
ctypedef const uint8_t const_uint8_t


cdef device_matrix_view[float, int64_t, row_major] get_dmv_float(
array, check_shape) except *
@@ -49,6 +55,15 @@ cdef optional[device_matrix_view[int64_t, int64_t, row_major]] make_optional_vie
cdef device_matrix_view[uint32_t, int64_t, row_major] get_dmv_uint32(
array, check_shape) except *

cdef device_matrix_view[const_float, int64_t, row_major] get_const_dmv_float(
array, check_shape) except *

cdef device_matrix_view[const_uint8_t, int64_t, row_major] get_const_dmv_uint8(
array, check_shape) except *

cdef device_matrix_view[const_int8_t, int64_t, row_major] get_const_dmv_int8(
array, check_shape) except *

cdef host_matrix_view[float, int64_t, row_major] get_hmv_float(
array, check_shape) except *

Expand All @@ -63,3 +78,12 @@ cdef host_matrix_view[int64_t, int64_t, row_major] get_hmv_int64(

cdef host_matrix_view[uint32_t, int64_t, row_major] get_hmv_uint32(
array, check_shape) except *

cdef host_matrix_view[const_float, int64_t, row_major] get_const_hmv_float(
array, check_shape) except *

cdef host_matrix_view[const_uint8_t, int64_t, row_major] get_const_hmv_uint8(
array, check_shape) except *

cdef host_matrix_view[const_int8_t, int64_t, row_major] get_const_hmv_int8(
array, check_shape) except *
67 changes: 66 additions & 1 deletion python/pylibraft/pylibraft/common/mdspan.pyx
@@ -193,6 +193,39 @@ cdef device_matrix_view[int64_t, int64_t, row_major] \
<int64_t*><uintptr_t>cai.data, shape[0], shape[1])


cdef device_matrix_view[const_float, int64_t, row_major] \
get_const_dmv_float(cai, check_shape) except *:
if cai.dtype != np.float32:
raise TypeError("dtype %s not supported" % cai.dtype)
if check_shape and len(cai.shape) != 2:
raise ValueError("Expected a 2D array, got %d D" % len(cai.shape))
shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1)
return make_device_matrix_view[const_float, int64_t, row_major](
<const float*><uintptr_t>cai.data, shape[0], shape[1])


cdef device_matrix_view[const_uint8_t, int64_t, row_major] \
get_const_dmv_uint8(cai, check_shape) except *:
if cai.dtype != np.uint8:
raise TypeError("dtype %s not supported" % cai.dtype)
if check_shape and len(cai.shape) != 2:
raise ValueError("Expected a 2D array, got %d D" % len(cai.shape))
shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1)
return make_device_matrix_view[const_uint8_t, int64_t, row_major](
<const uint8_t*><uintptr_t>cai.data, shape[0], shape[1])


cdef device_matrix_view[const_int8_t, int64_t, row_major] \
get_const_dmv_int8(cai, check_shape) except *:
if cai.dtype != np.int8:
raise TypeError("dtype %s not supported" % cai.dtype)
if check_shape and len(cai.shape) != 2:
raise ValueError("Expected a 2D array, got %d D" % len(cai.shape))
shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1)
return make_device_matrix_view[const_int8_t, int64_t, row_major](
<const int8_t*><uintptr_t>cai.data, shape[0], shape[1])


cdef optional[device_matrix_view[int64_t, int64_t, row_major]] \
make_optional_view_int64(device_matrix_view[int64_t, int64_t, row_major]& dmv) except *: # noqa: E501
return make_optional[device_matrix_view[int64_t, int64_t, row_major]](dmv)
@@ -222,7 +255,6 @@ cdef host_matrix_view[float, int64_t, row_major] \
return make_host_matrix_view[float, int64_t, row_major](
<float*><uintptr_t>cai.data, shape[0], shape[1])


cdef host_matrix_view[uint8_t, int64_t, row_major] \
get_hmv_uint8(cai, check_shape) except *:
if cai.dtype != np.uint8:
@@ -265,3 +297,36 @@ cdef host_matrix_view[uint32_t, int64_t, row_major] \
shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1)
return make_host_matrix_view[uint32_t, int64_t, row_major](
<uint32_t*><uintptr_t>cai.data, shape[0], shape[1])


cdef host_matrix_view[const_float, int64_t, row_major] \
get_const_hmv_float(cai, check_shape) except *:
if cai.dtype != np.float32:
raise TypeError("dtype %s not supported" % cai.dtype)
if check_shape and len(cai.shape) != 2:
raise ValueError("Expected a 2D array, got %d D" % len(cai.shape))
shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1)
return make_host_matrix_view[const_float, int64_t, row_major](
<const float*><uintptr_t>cai.data, shape[0], shape[1])


cdef host_matrix_view[const_uint8_t, int64_t, row_major] \
get_const_hmv_uint8(cai, check_shape) except *:
if cai.dtype != np.uint8:
raise TypeError("dtype %s not supported" % cai.dtype)
if check_shape and len(cai.shape) != 2:
raise ValueError("Expected a 2D array, got %d D" % len(cai.shape))
shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1)
return make_host_matrix_view[const_uint8_t, int64_t, row_major](
<const uint8_t*><uintptr_t>cai.data, shape[0], shape[1])


cdef host_matrix_view[const_int8_t, int64_t, row_major] \
get_const_hmv_int8(cai, check_shape) except *:
if cai.dtype != np.int8:
raise TypeError("dtype %s not supported" % cai.dtype)
if check_shape and len(cai.shape) != 2:
raise ValueError("Expected a 2D array, got %d D" % len(cai.shape))
shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1)
return make_host_matrix_view[const_int8_t, int64_t, row_major](
<const_int8_t*><uintptr_t>cai.data, shape[0], shape[1])