From ea9d39565cc6d5496e240a88da301e5be15769b8 Mon Sep 17 00:00:00 2001 From: Ben Frederickson Date: Mon, 21 Aug 2023 14:05:19 -0700 Subject: [PATCH] Update CAGRA serialization (#1755) This changes the serialization format of saved CAGRA indices by: * The dtype will now be written in the first 4 bytes of the serialized file, to match the IVF methods and to make it easier to deduce the dtype from python (#1729) * Writing out the dataset with the index is now optional. Since many use cases will already have the dataset written out separately, this gives us the option to save disk space by not writing out an extra copy of the input dataset. If the include_dataset=false option is given, you will have to call `index.update_dataset` to set the dataset yourself after loading Authors: - Ben Frederickson (https://github.com/benfred) Approvers: - Corey J. Nolet (https://github.com/cjnolet) URL: https://github.com/rapidsai/raft/pull/1755 --- .../raft/neighbors/cagra_serialize.cuh | 16 ++- .../detail/cagra/cagra_serialize.cuh | 54 +++++---- cpp/include/raft_runtime/neighbors/cagra.hpp | 6 +- .../raft_runtime/neighbors/cagra_serialize.cu | 10 +- cpp/test/neighbors/ann_cagra.cuh | 13 ++- python/pylibraft/pylibraft/common/mdspan.pxd | 24 ++++ python/pylibraft/pylibraft/common/mdspan.pyx | 67 ++++++++++- .../pylibraft/neighbors/cagra/cagra.pyx | 106 ++++++++++++++++-- .../pylibraft/neighbors/cagra/cpp/c_cagra.pxd | 30 ++++- python/pylibraft/pylibraft/test/test_cagra.py | 10 +- 10 files changed, 284 insertions(+), 52 deletions(-) diff --git a/cpp/include/raft/neighbors/cagra_serialize.cuh b/cpp/include/raft/neighbors/cagra_serialize.cuh index 2242629409..0a806402d2 100644 --- a/cpp/include/raft/neighbors/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/cagra_serialize.cuh @@ -47,12 +47,16 @@ namespace raft::neighbors::cagra { * @param[in] handle the raft handle * @param[in] os output stream * @param[in] index CAGRA index + * @param[in] include_dataset Whether or not to write out the dataset to the file. * */ template -void serialize(raft::resources const& handle, std::ostream& os, const index& index) +void serialize(raft::resources const& handle, + std::ostream& os, + const index& index, + bool include_dataset = true) { - detail::serialize(handle, os, index); + detail::serialize(handle, os, index, include_dataset); } /** @@ -77,14 +81,16 @@ void serialize(raft::resources const& handle, std::ostream& os, const index void serialize(raft::resources const& handle, const std::string& filename, - const index& index) + const index& index, + bool include_dataset = true) { - detail::serialize(handle, filename, index); + detail::serialize(handle, filename, index, include_dataset); } /** @@ -158,4 +164,4 @@ namespace raft::neighbors::experimental::cagra { using raft::neighbors::cagra::deserialize; using raft::neighbors::cagra::serialize; -} // namespace raft::neighbors::experimental::cagra \ No newline at end of file +} // namespace raft::neighbors::experimental::cagra diff --git a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh index 8d040c352b..2c9cbd2563 100644 --- a/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh +++ b/cpp/include/raft/neighbors/detail/cagra/cagra_serialize.cuh @@ -24,8 +24,7 @@ namespace raft::neighbors::cagra::detail { -// Serialization version 1. -constexpr int serialization_version = 2; +constexpr int serialization_version = 3; // NB: we wrap this check in a struct, so that the updated RealSize is easy to see in the error // message. @@ -50,41 +49,53 @@ template struct check_index_layout), expecte * */ template -void serialize(raft::resources const& res, std::ostream& os, const index& index_) +void serialize(raft::resources const& res, + std::ostream& os, + const index& index_, + bool include_dataset) { RAFT_LOG_DEBUG( "Saving CAGRA index, size %zu, dim %u", static_cast(index_.size()), index_.dim()); + std::string dtype_string = raft::detail::numpy_serializer::get_numpy_dtype().to_string(); + dtype_string.resize(4); + os << dtype_string; + serialize_scalar(res, os, serialization_version); serialize_scalar(res, os, index_.size()); serialize_scalar(res, os, index_.dim()); serialize_scalar(res, os, index_.graph_degree()); serialize_scalar(res, os, index_.metric()); - auto dataset = index_.dataset(); - // Remove padding before saving the dataset - auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); - RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), - sizeof(T) * host_dataset.extent(1), - dataset.data_handle(), - sizeof(T) * dataset.stride(0), - sizeof(T) * host_dataset.extent(1), - dataset.extent(0), - cudaMemcpyDefault, - resource::get_cuda_stream(res))); - resource::sync_stream(res); - serialize_mdspan(res, os, host_dataset.view()); serialize_mdspan(res, os, index_.graph()); + + serialize_scalar(res, os, include_dataset); + if (include_dataset) { + auto dataset = index_.dataset(); + // Remove padding before saving the dataset + auto host_dataset = make_host_matrix(dataset.extent(0), dataset.extent(1)); + RAFT_CUDA_TRY(cudaMemcpy2DAsync(host_dataset.data_handle(), + sizeof(T) * host_dataset.extent(1), + dataset.data_handle(), + sizeof(T) * dataset.stride(0), + sizeof(T) * host_dataset.extent(1), + dataset.extent(0), + cudaMemcpyDefault, + resource::get_cuda_stream(res))); + resource::sync_stream(res); + serialize_mdspan(res, os, host_dataset.view()); + } } template void serialize(raft::resources const& res, const std::string& filename, - const index& index_) + const index& index_, + bool include_dataset) { std::ofstream of(filename, std::ios::out | std::ios::binary); if (!of) { RAFT_FAIL("Cannot open file %s", filename.c_str()); } - detail::serialize(res, of, index_); + detail::serialize(res, of, index_, include_dataset); of.close(); if (!of) { RAFT_FAIL("Error writing output %s", filename.c_str()); } @@ -102,6 +113,9 @@ void serialize(raft::resources const& res, template auto deserialize(raft::resources const& res, std::istream& is) -> index { + char dtype_string[4]; + is.read(dtype_string, 4); + auto ver = deserialize_scalar(res, is); if (ver != serialization_version) { RAFT_FAIL("serialization version mismatch, expected %d, got %d ", serialization_version, ver); @@ -113,9 +127,11 @@ auto deserialize(raft::resources const& res, std::istream& is) -> index auto dataset = raft::make_host_matrix(n_rows, dim); auto graph = raft::make_host_matrix(n_rows, graph_degree); - deserialize_mdspan(res, is, dataset.view()); deserialize_mdspan(res, is, graph.view()); + bool has_dataset = deserialize_scalar(res, is); + if (has_dataset) { deserialize_mdspan(res, is, dataset.view()); } + return index( res, metric, raft::make_const_mdspan(dataset.view()), raft::make_const_mdspan(graph.view())); } diff --git a/cpp/include/raft_runtime/neighbors/cagra.hpp b/cpp/include/raft_runtime/neighbors/cagra.hpp index 6f56302776..c54ed32b77 100644 --- a/cpp/include/raft_runtime/neighbors/cagra.hpp +++ b/cpp/include/raft_runtime/neighbors/cagra.hpp @@ -56,14 +56,16 @@ namespace raft::runtime::neighbors::cagra { raft::device_matrix_view distances); \ void serialize_file(raft::resources const& handle, \ const std::string& filename, \ - const raft::neighbors::cagra::index& index); \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ \ void deserialize_file(raft::resources const& handle, \ const std::string& filename, \ raft::neighbors::cagra::index* index); \ void serialize(raft::resources const& handle, \ std::string& str, \ - const raft::neighbors::cagra::index& index); \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset = true); \ \ void deserialize(raft::resources const& handle, \ const std::string& str, \ diff --git a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu index be9788562a..69b48b93a4 100644 --- a/cpp/src/raft_runtime/neighbors/cagra_serialize.cu +++ b/cpp/src/raft_runtime/neighbors/cagra_serialize.cu @@ -27,9 +27,10 @@ namespace raft::runtime::neighbors::cagra { #define RAFT_INST_CAGRA_SERIALIZE(DTYPE) \ void serialize_file(raft::resources const& handle, \ const std::string& filename, \ - const raft::neighbors::cagra::index& index) \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ { \ - raft::neighbors::cagra::serialize(handle, filename, index); \ + raft::neighbors::cagra::serialize(handle, filename, index, include_dataset); \ }; \ \ void deserialize_file(raft::resources const& handle, \ @@ -41,10 +42,11 @@ namespace raft::runtime::neighbors::cagra { }; \ void serialize(raft::resources const& handle, \ std::string& str, \ - const raft::neighbors::cagra::index& index) \ + const raft::neighbors::cagra::index& index, \ + bool include_dataset) \ { \ std::stringstream os; \ - raft::neighbors::cagra::serialize(handle, os, index); \ + raft::neighbors::cagra::serialize(handle, os, index, include_dataset); \ str = os.str(); \ } \ \ diff --git a/cpp/test/neighbors/ann_cagra.cuh b/cpp/test/neighbors/ann_cagra.cuh index 89cb070afc..ea905d2089 100644 --- a/cpp/test/neighbors/ann_cagra.cuh +++ b/cpp/test/neighbors/ann_cagra.cuh @@ -137,6 +137,7 @@ struct AnnCagraInputs { int search_width; raft::distance::DistanceType metric; bool host_dataset; + bool include_serialized_dataset; // std::optional double min_recall; // = std::nullopt; }; @@ -217,9 +218,11 @@ class AnnCagraTest : public ::testing::TestWithParam { } else { index = cagra::build(handle_, index_params, database_view); }; - cagra::serialize(handle_, "cagra_index", index); + cagra::serialize(handle_, "cagra_index", index, ps.include_serialized_dataset); } + auto index = cagra::deserialize(handle_, "cagra_index"); + if (!ps.include_serialized_dataset) { index.update_dataset(handle_, database_view); } auto search_queries_view = raft::make_device_matrix_view( search_queries.data(), ps.n_queries, ps.dim); @@ -340,9 +343,7 @@ class AnnCagraSortTest : public ::testing::TestWithParam { void SetUp() override { - std::cout << "Resizing database: " << ps.n_rows * ps.dim << std::endl; database.resize(((size_t)ps.n_rows) * ps.dim, handle_.get_stream()); - std::cout << "Done.\nRuning rng" << std::endl; raft::random::Rng r(1234ULL); if constexpr (std::is_same{}) { GenerateRoundingErrorFreeDataset(database.data(), ps.n_rows, ps.dim, r, handle_.get_stream()); @@ -379,6 +380,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false}, + {true}, {0.995}); auto inputs2 = raft::util::itertools::product( @@ -393,6 +395,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false}, + {true}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); inputs2 = @@ -407,6 +410,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false}, + {false}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -422,6 +426,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false}, + {true}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -437,6 +442,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false, true}, + {false}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); @@ -452,6 +458,7 @@ inline std::vector generate_inputs() {1}, {raft::distance::DistanceType::L2Expanded}, {false, true}, + {true}, {0.995}); inputs.insert(inputs.end(), inputs2.begin(), inputs2.end()); diff --git a/python/pylibraft/pylibraft/common/mdspan.pxd b/python/pylibraft/pylibraft/common/mdspan.pxd index 6b202c2b69..17dd2d8bfd 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pxd +++ b/python/pylibraft/pylibraft/common/mdspan.pxd @@ -30,6 +30,12 @@ from pylibraft.common.cpp.mdspan cimport ( from pylibraft.common.handle cimport device_resources from pylibraft.common.optional cimport make_optional, optional +# Cython doesn't like `const float` inside template parameters +# hack around this with using typedefs +ctypedef const float const_float +ctypedef const int8_t const_int8_t +ctypedef const uint8_t const_uint8_t + cdef device_matrix_view[float, int64_t, row_major] get_dmv_float( array, check_shape) except * @@ -49,6 +55,15 @@ cdef optional[device_matrix_view[int64_t, int64_t, row_major]] make_optional_vie cdef device_matrix_view[uint32_t, int64_t, row_major] get_dmv_uint32( array, check_shape) except * +cdef device_matrix_view[const_float, int64_t, row_major] get_const_dmv_float( + array, check_shape) except * + +cdef device_matrix_view[const_uint8_t, int64_t, row_major] get_const_dmv_uint8( + array, check_shape) except * + +cdef device_matrix_view[const_int8_t, int64_t, row_major] get_const_dmv_int8( + array, check_shape) except * + cdef host_matrix_view[float, int64_t, row_major] get_hmv_float( array, check_shape) except * @@ -63,3 +78,12 @@ cdef host_matrix_view[int64_t, int64_t, row_major] get_hmv_int64( cdef host_matrix_view[uint32_t, int64_t, row_major] get_hmv_uint32( array, check_shape) except * + +cdef host_matrix_view[const_float, int64_t, row_major] get_const_hmv_float( + array, check_shape) except * + +cdef host_matrix_view[const_uint8_t, int64_t, row_major] get_const_hmv_uint8( + array, check_shape) except * + +cdef host_matrix_view[const_int8_t, int64_t, row_major] get_const_hmv_int8( + array, check_shape) except * diff --git a/python/pylibraft/pylibraft/common/mdspan.pyx b/python/pylibraft/pylibraft/common/mdspan.pyx index 1219b1612d..7442a6bb89 100644 --- a/python/pylibraft/pylibraft/common/mdspan.pyx +++ b/python/pylibraft/pylibraft/common/mdspan.pyx @@ -193,6 +193,39 @@ cdef device_matrix_view[int64_t, int64_t, row_major] \ cai.data, shape[0], shape[1]) +cdef device_matrix_view[const_float, int64_t, row_major] \ + get_const_dmv_float(cai, check_shape) except *: + if cai.dtype != np.float32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[const_float, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[const_uint8_t, int64_t, row_major] \ + get_const_dmv_uint8(cai, check_shape) except *: + if cai.dtype != np.uint8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[const_uint8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef device_matrix_view[const_int8_t, int64_t, row_major] \ + get_const_dmv_int8(cai, check_shape) except *: + if cai.dtype != np.int8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_device_matrix_view[const_int8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + cdef optional[device_matrix_view[int64_t, int64_t, row_major]] \ make_optional_view_int64(device_matrix_view[int64_t, int64_t, row_major]& dmv) except *: # noqa: E501 return make_optional[device_matrix_view[int64_t, int64_t, row_major]](dmv) @@ -222,7 +255,6 @@ cdef host_matrix_view[float, int64_t, row_major] \ return make_host_matrix_view[float, int64_t, row_major]( cai.data, shape[0], shape[1]) - cdef host_matrix_view[uint8_t, int64_t, row_major] \ get_hmv_uint8(cai, check_shape) except *: if cai.dtype != np.uint8: @@ -265,3 +297,36 @@ cdef host_matrix_view[uint32_t, int64_t, row_major] \ shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) return make_host_matrix_view[uint32_t, int64_t, row_major]( cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[const_float, int64_t, row_major] \ + get_const_hmv_float(cai, check_shape) except *: + if cai.dtype != np.float32: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[const_float, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[const_uint8_t, int64_t, row_major] \ + get_const_hmv_uint8(cai, check_shape) except *: + if cai.dtype != np.uint8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[const_uint8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) + + +cdef host_matrix_view[const_int8_t, int64_t, row_major] \ + get_const_hmv_int8(cai, check_shape) except *: + if cai.dtype != np.int8: + raise TypeError("dtype %s not supported" % cai.dtype) + if check_shape and len(cai.shape) != 2: + raise ValueError("Expected a 2D array, got %d D" % len(cai.shape)) + shape = (cai.shape[0], cai.shape[1] if len(cai.shape) == 2 else 1) + return make_host_matrix_view[const_int8_t, int64_t, row_major]( + cai.data, shape[0], shape[1]) diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx index fbc1623cac..e0c59a5ed3 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx +++ b/python/pylibraft/pylibraft/neighbors/cagra/cagra.pyx @@ -69,6 +69,12 @@ from pylibraft.common.cpp.mdspan cimport ( row_major, ) from pylibraft.common.mdspan cimport ( + get_const_dmv_float, + get_const_dmv_int8, + get_const_dmv_uint8, + get_const_hmv_float, + get_const_hmv_int8, + get_const_hmv_uint8, get_dmv_float, get_dmv_int8, get_dmv_int64, @@ -162,6 +168,31 @@ cdef class IndexFloat(Index): attr_str = [m_str] + attr_str return "Index(type=CAGRA, " + (", ".join(attr_str)) + ")" + @auto_sync_handle + def update_dataset(self, dataset, handle=None): + """ Replace the dataset with a new dataset. + + Parameters + ---------- + dataset : array interface compliant matrix shape (n_samples, dim) + {handle_docstring} + """ + cdef device_resources* handle_ = \ + handle.getHandle() + + dataset_ai = wrap_array(dataset) + dataset_dt = dataset_ai.dtype + _check_input_array(dataset_ai, [np.dtype("float32")]) + + if dataset_ai.from_cai: + self.index[0].update_dataset(deref(handle_), + get_const_dmv_float(dataset_ai, + check_shape=True)) + else: + self.index[0].update_dataset(deref(handle_), + get_const_hmv_float(dataset_ai, + check_shape=True)) + @property def metric(self): return self.index[0].metric() @@ -195,6 +226,31 @@ cdef class IndexInt8(Index): self.index = new c_cagra.index[int8_t, uint32_t]( deref(handle_)) + @auto_sync_handle + def update_dataset(self, dataset, handle=None): + """ Replace the dataset with a new dataset. + + Parameters + ---------- + dataset : array interface compliant matrix shape (n_samples, dim) + {handle_docstring} + """ + cdef device_resources* handle_ = \ + handle.getHandle() + + dataset_ai = wrap_array(dataset) + dataset_dt = dataset_ai.dtype + _check_input_array(dataset_ai, [np.dtype("byte")]) + + if dataset_ai.from_cai: + self.index[0].update_dataset(deref(handle_), + get_const_dmv_int8(dataset_ai, + check_shape=True)) + else: + self.index[0].update_dataset(deref(handle_), + get_const_hmv_int8(dataset_ai, + check_shape=True)) + def __repr__(self): m_str = "metric=" + _get_metric_string(self.index.metric()) attr_str = [attr + "=" + str(getattr(self, attr)) @@ -235,6 +291,31 @@ cdef class IndexUint8(Index): self.index = new c_cagra.index[uint8_t, uint32_t]( deref(handle_)) + @auto_sync_handle + def update_dataset(self, dataset, handle=None): + """ Replace the dataset with a new dataset. + + Parameters + ---------- + dataset : array interface compliant matrix shape (n_samples, dim) + {handle_docstring} + """ + cdef device_resources* handle_ = \ + handle.getHandle() + + dataset_ai = wrap_array(dataset) + dataset_dt = dataset_ai.dtype + _check_input_array(dataset_ai, [np.dtype("ubyte")]) + + if dataset_ai.from_cai: + self.index[0].update_dataset(deref(handle_), + get_const_dmv_uint8(dataset_ai, + check_shape=True)) + else: + self.index[0].update_dataset(deref(handle_), + get_const_hmv_uint8(dataset_ai, + check_shape=True)) + def __repr__(self): m_str = "metric=" + _get_metric_string(self.index.metric()) attr_str = [attr + "=" + str(getattr(self, attr)) @@ -693,7 +774,7 @@ def search(SearchParams search_params, @auto_sync_handle -def save(filename, Index index, handle=None): +def save(filename, Index index, bool include_dataset=True, handle=None): """ Saves the index to a file. @@ -706,6 +787,12 @@ def save(filename, Index index, handle=None): Name of the file. index : Index Trained CAGRA index. + include_dataset : bool + Whether or not to write out the dataset along with the index. Including + the dataset in the serialized index will use extra disk space, and + might not be desired if you already have a copy of the dataset on + disk. If this option is set to false, you will have to call + `index.update_dataset(dataset)` after loading the index. {handle_docstring} Examples @@ -741,15 +828,17 @@ def save(filename, Index index, handle=None): if index.active_index_type == "float32": idx_float = index c_cagra.serialize_file( - deref(handle_), c_filename, deref(idx_float.index)) + deref(handle_), c_filename, deref(idx_float.index), + include_dataset) elif index.active_index_type == "byte": idx_int8 = index c_cagra.serialize_file( - deref(handle_), c_filename, deref(idx_int8.index)) + deref(handle_), c_filename, deref(idx_int8.index), include_dataset) elif index.active_index_type == "ubyte": idx_uint8 = index c_cagra.serialize_file( - deref(handle_), c_filename, deref(idx_uint8.index)) + deref(handle_), c_filename, deref(idx_uint8.index), + include_dataset) else: raise ValueError( "Index dtype %s not supported" % index.active_index_type) @@ -785,12 +874,9 @@ def load(filename, handle=None): cdef IndexInt8 idx_int8 cdef IndexUint8 idx_uint8 - # we extract the dtype from the array interfaces in the file - with open(filename, 'rb') as f: - type_str = f.read(700).decode("utf-8", errors='ignore') - - # Read description of the 6th element to get the datatype - dataset_dt = np.dtype(type_str.split('descr')[6][5:7]) + with open(filename, "rb") as f: + type_str = f.read(3).decode("utf8") + dataset_dt = np.dtype(type_str) if dataset_dt == np.float32: idx_float = IndexFloat(handle) diff --git a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd index 284c75b771..0c683bcd9b 100644 --- a/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd +++ b/python/pylibraft/pylibraft/neighbors/cagra/cpp/c_cagra.pxd @@ -36,6 +36,7 @@ from pylibraft.common.cpp.mdspan cimport ( row_major, ) from pylibraft.common.handle cimport device_resources +from pylibraft.common.mdspan cimport const_float, const_int8_t, const_uint8_t from pylibraft.common.optional cimport optional from pylibraft.distance.distance_type cimport DistanceType from pylibraft.neighbors.ivf_pq.cpp.c_ivf_pq cimport ( @@ -90,6 +91,17 @@ cdef extern from "raft/neighbors/cagra_types.hpp" \ device_matrix_view[T, IdxT, row_major] dataset() device_matrix_view[T, IdxT, row_major] graph() + # hack: can't use the T template param here because of issues handling + # const w/ cython. introduce a new template param to get around this + void update_dataset[ValueT](const device_resources & handle, + host_matrix_view[ValueT, + int64_t, + row_major] dataset) + void update_dataset[ValueT](const device_resources & handle, + device_matrix_view[ValueT, + int64_t, + row_major] dataset) + cdef extern from "raft_runtime/neighbors/cagra.hpp" \ namespace "raft::runtime::neighbors::cagra" nogil: @@ -155,7 +167,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize(const device_resources& handle, string& str, - const index[float, uint32_t]& index) except + + const index[float, uint32_t]& index, + bool include_dataset) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -163,7 +176,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize(const device_resources& handle, string& str, - const index[uint8_t, uint32_t]& index) except + + const index[uint8_t, uint32_t]& index, + bool include_dataset) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -171,7 +185,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize(const device_resources& handle, string& str, - const index[int8_t, uint32_t]& index) except + + const index[int8_t, uint32_t]& index, + bool include_dataset) except + cdef void deserialize(const device_resources& handle, const string& str, @@ -179,7 +194,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize_file(const device_resources& handle, const string& filename, - const index[float, uint32_t]& index) except + + const index[float, uint32_t]& index, + bool include_dataset) except + cdef void deserialize_file(const device_resources& handle, const string& filename, @@ -187,7 +203,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize_file(const device_resources& handle, const string& filename, - const index[uint8_t, uint32_t]& index) except + + const index[uint8_t, uint32_t]& index, + bool include_dataset) except + cdef void deserialize_file(const device_resources& handle, const string& filename, @@ -195,7 +212,8 @@ cdef extern from "raft_runtime/neighbors/cagra.hpp" \ cdef void serialize_file(const device_resources& handle, const string& filename, - const index[int8_t, uint32_t]& index) except + + const index[int8_t, uint32_t]& index, + bool include_dataset) except + cdef void deserialize_file(const device_resources& handle, const string& filename, diff --git a/python/pylibraft/pylibraft/test/test_cagra.py b/python/pylibraft/pylibraft/test/test_cagra.py index 435b2878a2..74e9f53b91 100644 --- a/python/pylibraft/pylibraft/test/test_cagra.py +++ b/python/pylibraft/pylibraft/test/test_cagra.py @@ -255,7 +255,8 @@ def test_cagra_search_params(params): @pytest.mark.parametrize("dtype", [np.float32, np.int8, np.ubyte]) -def test_save_load(dtype): +@pytest.mark.parametrize("include_dataset", [True, False]) +def test_save_load(dtype, include_dataset): n_rows = 10000 n_cols = 50 n_queries = 1000 @@ -268,9 +269,14 @@ def test_save_load(dtype): assert index.trained filename = "my_index.bin" - cagra.save(filename, index) + cagra.save(filename, index, include_dataset=include_dataset) loaded_index = cagra.load(filename) + # if we didn't save the dataset with the index, we need to update the + # index with an already loaded copy + if not include_dataset: + loaded_index.update_dataset(dataset) + queries = generate_data((n_queries, n_cols), dtype) queries_device = device_ndarray(queries)